General improvements to database backend...
MinRK
@@ -0,0 +1,182 @@
1 """Tests for db backends"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14
15 import tempfile
16 import time
17
18 import uuid
19
20 from datetime import datetime, timedelta
21 from random import choice, randint
22 from unittest import TestCase
23
24 from nose import SkipTest
25
26 from IPython.parallel import error, streamsession as ss
27 from IPython.parallel.controller.dictdb import DictDB
28 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 from IPython.parallel.controller.hub import init_record, empty_record
30
31 #-------------------------------------------------------------------------------
32 # TestCases
33 #-------------------------------------------------------------------------------
34
35 class TestDictBackend(TestCase):
36 def setUp(self):
37 self.session = ss.StreamSession()
38 self.db = self.create_db()
39 self.load_records(16)
40
41 def create_db(self):
42 return DictDB()
43
44 def load_records(self, n=1):
45 """load n records for testing"""
46 #sleep 1/10 s, to ensure timestamp is different to previous calls
47 time.sleep(0.1)
48 msg_ids = []
49 for i in range(n):
50 msg = self.session.msg('apply_request', content=dict(a=5))
51 msg['buffers'] = []
52 rec = init_record(msg)
53 msg_ids.append(msg['msg_id'])
54 self.db.add_record(msg['msg_id'], rec)
55 return msg_ids
56
57 def test_add_record(self):
58 before = self.db.get_history()
59 self.load_records(5)
60 after = self.db.get_history()
61 self.assertEquals(len(after), len(before)+5)
62 self.assertEquals(after[:-5],before)
63
64 def test_drop_record(self):
65 msg_id = self.load_records()[-1]
66 rec = self.db.get_record(msg_id)
67 self.db.drop_record(msg_id)
68 self.assertRaises(KeyError,self.db.get_record, msg_id)
69
70 def _round_to_millisecond(self, dt):
71 """necessary because mongodb rounds microseconds"""
72 micro = dt.microsecond
73 extra = int(str(micro)[-3:])
74 return dt - timedelta(microseconds=extra)
75
76 def test_update_record(self):
77 now = self._round_to_millisecond(datetime.now())
78 #
79 msg_id = self.db.get_history()[-1]
80 rec1 = self.db.get_record(msg_id)
81 data = {'stdout': 'hello there', 'completed' : now}
82 self.db.update_record(msg_id, data)
83 rec2 = self.db.get_record(msg_id)
84 self.assertEquals(rec2['stdout'], 'hello there')
85 self.assertEquals(rec2['completed'], now)
86 rec1.update(data)
87 self.assertEquals(rec1, rec2)
88
89 # def test_update_record_bad(self):
90 # """test updating nonexistant records"""
91 # msg_id = str(uuid.uuid4())
92 # data = {'stdout': 'hello there'}
93 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
94
95 def test_find_records_dt(self):
96 """test finding records by date"""
97 hist = self.db.get_history()
98 middle = self.db.get_record(hist[len(hist)/2])
99 tic = middle['submitted']
100 before = self.db.find_records({'submitted' : {'$lt' : tic}})
101 after = self.db.find_records({'submitted' : {'$gte' : tic}})
102 self.assertEquals(len(before)+len(after),len(hist))
103 for b in before:
104 self.assertTrue(b['submitted'] < tic)
105 for a in after:
106 self.assertTrue(a['submitted'] >= tic)
107 same = self.db.find_records({'submitted' : tic})
108 for s in same:
109 self.assertTrue(s['submitted'] == tic)
110
111 def test_find_records_keys(self):
112 """test extracting subset of record keys"""
113 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
114 for rec in found:
115 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
116
117 def test_find_records_msg_id(self):
118 """ensure msg_id is always in found records"""
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 for rec in found:
121 self.assertTrue('msg_id' in rec.keys())
122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
123 for rec in found:
124 self.assertTrue('msg_id' in rec.keys())
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
126 for rec in found:
127 self.assertTrue('msg_id' in rec.keys())
128
129 def test_find_records_in(self):
130 """test finding records with '$in','$nin' operators"""
131 hist = self.db.get_history()
132 even = hist[::2]
133 odd = hist[1::2]
134 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
135 found = [ r['msg_id'] for r in recs ]
136 self.assertEquals(set(even), set(found))
137 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
138 found = [ r['msg_id'] for r in recs ]
139 self.assertEquals(set(odd), set(found))
140
141 def test_get_history(self):
142 msg_ids = self.db.get_history()
143 latest = datetime(1984,1,1)
144 for msg_id in msg_ids:
145 rec = self.db.get_record(msg_id)
146 newt = rec['submitted']
147 self.assertTrue(newt >= latest)
148 latest = newt
149 msg_id = self.load_records(1)[-1]
150 self.assertEquals(self.db.get_history()[-1],msg_id)
151
152 def test_datetime(self):
153 """get/set timestamps with datetime objects"""
154 msg_id = self.db.get_history()[-1]
155 rec = self.db.get_record(msg_id)
156 self.assertTrue(isinstance(rec['submitted'], datetime))
157 self.db.update_record(msg_id, dict(completed=datetime.now()))
158 rec = self.db.get_record(msg_id)
159 self.assertTrue(isinstance(rec['completed'], datetime))
160
161 class TestSQLiteBackend(TestDictBackend):
162 def create_db(self):
163 return SQLiteDB(location=tempfile.gettempdir())
164
165 def tearDown(self):
166 self.db._db.close()
167
168 # optional MongoDB test
169 try:
170 from IPython.parallel.controller.mongodb import MongoDB
171 except ImportError:
172 pass
173 else:
174 class TestMongoBackend(TestDictBackend):
175 def create_db(self):
176 try:
177 return MongoDB(database='iptestdb')
178 except Exception:
179 raise SkipTest("Couldn't connect to mongodb instance")
180
181 def tearDown(self):
182 self.db._connection.drop_database('iptestdb')
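
These tests double as a specification for the backend query interface: every backend (DictDB, SQLiteDB, MongoDB) must accept Mongo-style operators ($lt, $gte, $ne, $in, $nin) in find_records, always include msg_id in returned records, and round-trip datetime timestamps. A minimal sketch of that interface, assuming the module layout used in this commit and that DictDB accepts any dict-shaped record:

from datetime import datetime
from IPython.parallel.controller.dictdb import DictDB

db = DictDB()
# records are plain dicts keyed by msg_id
db.add_record('mid-1', {'msg_id': 'mid-1', 'submitted': datetime.now()})
# Mongo-style operators work against any backend
recs = db.find_records({'submitted': {'$lt': datetime.now()}},
                       keys=['submitted'])  # msg_id is always included
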
@@ -1,1219 +1,1292 @@
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.path import get_ipython_dir
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 Dict, List, Bool, Str, Set)
28 Dict, List, Bool, Str, Set)
29 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
30 from IPython.external.ssh import tunnel
30 from IPython.external.ssh import tunnel
31
31
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import streamsession as ss
33 from IPython.parallel import streamsession as ss
34 from IPython.parallel import util
34 from IPython.parallel import util
35
35
36 from .asyncresult import AsyncResult, AsyncHubResult
36 from .asyncresult import AsyncResult, AsyncHubResult
37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
38 from .view import DirectView, LoadBalancedView
38 from .view import DirectView, LoadBalancedView
39
39
40 #--------------------------------------------------------------------------
40 #--------------------------------------------------------------------------
41 # Decorators for Client methods
41 # Decorators for Client methods
42 #--------------------------------------------------------------------------
42 #--------------------------------------------------------------------------
43
43
44 @decorator
44 @decorator
45 def spin_first(f, self, *args, **kwargs):
45 def spin_first(f, self, *args, **kwargs):
46 """Call spin() to sync state prior to calling the method."""
46 """Call spin() to sync state prior to calling the method."""
47 self.spin()
47 self.spin()
48 return f(self, *args, **kwargs)
48 return f(self, *args, **kwargs)
49
49
50
50
51 #--------------------------------------------------------------------------
51 #--------------------------------------------------------------------------
52 # Classes
52 # Classes
53 #--------------------------------------------------------------------------
53 #--------------------------------------------------------------------------
54
54
55 class Metadata(dict):
55 class Metadata(dict):
56 """Subclass of dict for initializing metadata values.
56 """Subclass of dict for initializing metadata values.
57
57
58 Attribute access works on keys.
58 Attribute access works on keys.
59
59
60 These objects have a strict set of keys - errors will raise if you try
60 These objects have a strict set of keys - errors will raise if you try
61 to add new keys.
61 to add new keys.
62 """
62 """
63 def __init__(self, *args, **kwargs):
63 def __init__(self, *args, **kwargs):
64 dict.__init__(self)
64 dict.__init__(self)
65 md = {'msg_id' : None,
65 md = {'msg_id' : None,
66 'submitted' : None,
66 'submitted' : None,
67 'started' : None,
67 'started' : None,
68 'completed' : None,
68 'completed' : None,
69 'received' : None,
69 'received' : None,
70 'engine_uuid' : None,
70 'engine_uuid' : None,
71 'engine_id' : None,
71 'engine_id' : None,
72 'follow' : None,
72 'follow' : None,
73 'after' : None,
73 'after' : None,
74 'status' : None,
74 'status' : None,
75
75
76 'pyin' : None,
76 'pyin' : None,
77 'pyout' : None,
77 'pyout' : None,
78 'pyerr' : None,
78 'pyerr' : None,
79 'stdout' : '',
79 'stdout' : '',
80 'stderr' : '',
80 'stderr' : '',
81 }
81 }
82 self.update(md)
82 self.update(md)
83 self.update(dict(*args, **kwargs))
83 self.update(dict(*args, **kwargs))
84
84
85 def __getattr__(self, key):
85 def __getattr__(self, key):
86 """getattr aliased to getitem"""
86 """getattr aliased to getitem"""
87 if key in self.iterkeys():
87 if key in self.iterkeys():
88 return self[key]
88 return self[key]
89 else:
89 else:
90 raise AttributeError(key)
90 raise AttributeError(key)
91
91
92 def __setattr__(self, key, value):
92 def __setattr__(self, key, value):
93 """setattr aliased to setitem, with strict"""
93 """setattr aliased to setitem, with strict"""
94 if key in self.iterkeys():
94 if key in self.iterkeys():
95 self[key] = value
95 self[key] = value
96 else:
96 else:
97 raise AttributeError(key)
97 raise AttributeError(key)
98
98
99 def __setitem__(self, key, value):
99 def __setitem__(self, key, value):
100 """strict static key enforcement"""
100 """strict static key enforcement"""
101 if key in self.iterkeys():
101 if key in self.iterkeys():
102 dict.__setitem__(self, key, value)
102 dict.__setitem__(self, key, value)
103 else:
103 else:
104 raise KeyError(key)
104 raise KeyError(key)
105
105
106
106
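# A minimal usage sketch of the strict-key behavior implemented above:
#
#     md = Metadata(status='ok')
#     md.stdout += 'hello\n'   # attribute access aliases item access
#     md.foo = 1               # raises AttributeError: not a predefined key
#     md['foo'] = 1            # raises KeyError for the same reason
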
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster

    Parameters
    ----------

    url_or_file : bytes; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration. If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        [Default: 'default']
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    username : bytes
        set username to be passed to the Session object
    debug : bool
        flag for lots of message printing for debug purposes

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False else]

    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by using a key. Messages are still sent
    as cleartext, so if someone can snoop your loopback traffic this will
    not help against malicious attacks.

    exec_key : str
        an authentication key or file containing a key
        default: None


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """

    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict,))
    metadata = Instance('collections.defaultdict', (Metadata,))
    history = List()
    debug = Bool(False)
    profile=CUnicode('default')

    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Str()
    _closed = False
    _ignored_control_replies=Int(0)
    _ignored_hub_replies=Int(0)

    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
                 context=None, username=None, debug=False, exec_key=None,
                 sshserver=None, sshkey=None, password=None, paramiko=None,
                 timeout=10
                 ):
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context.instance()
        self._context = context


        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                       'unregistration_notification' : self._unregister_engine,
                                       'shutdown_notification' : lambda msg: self.close(),
                                      }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if cluster_dir is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir(cluster_dir)
                return
            except ClusterDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir_by_profile(
                    ipython_dir, profile)
                return
            except ClusterDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (uuids, int_ids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [self._engines[t] for t in targets], list(targets)

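    # A sketch of the target forms _build_targets resolves, per the logic above:
    #
    #     client._build_targets(None)          # all registered engines
    #     client._build_targets('all')         # the only valid string form
    #     client._build_targets(-1)            # negative ints index self.ids
    #     client._build_targets(slice(0, 4))   # slices resolve against current ids
    #     client._build_targets([0, 2, 3])     # list/tuple/xrange of ints
    #
    # each call returns (uuids, int_ids) for the resolved engines
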
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        r,w,x = zmq.select([self._query_socket],[],[], timeout)
        if not r:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
             }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
        if 'started' in header:
            md['started'] = datetime.strptime(header['started'], util.ISO8601)
        if 'date' in header:
            md['completed'] = datetime.strptime(header['date'], util.ISO8601)
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid : content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
        else:
            return self.direct_view(key)

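    # A sketch of index access, which wraps direct_view() as shown above:
    #
    #     e0 = client[0]       # DirectView on engine 0
    #     dv = client[::2]     # DirectView on every other engine
    #     dv = client[[1, 3]]  # DirectView on an explicit list of engines
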
696 #--------------------------------------------------------------------------
696 #--------------------------------------------------------------------------
697 # Begin public methods
698 #--------------------------------------------------------------------------
699
700 @property
701 def ids(self):
702 """Always up-to-date ids property."""
703 self._flush_notifications()
704 # always copy:
705 return list(self._ids)
706
707 def close(self):
708 if self._closed:
709 return
710 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 for socket in map(lambda name: getattr(self, name), snames):
712 if isinstance(socket, zmq.Socket) and not socket.closed:
713 socket.close()
714 self._closed = True
715
716 def spin(self):
717 """Flush any registration notifications and execution results
718 waiting in the ZMQ queue.
719 """
720 if self._notification_socket:
721 self._flush_notifications()
722 if self._mux_socket:
723 self._flush_results(self._mux_socket)
724 if self._task_socket:
725 self._flush_results(self._task_socket)
726 if self._control_socket:
727 self._flush_control(self._control_socket)
728 if self._iopub_socket:
729 self._flush_iopub(self._iopub_socket)
730 if self._query_socket:
731 self._flush_ignored_hub_replies()
732
733 def wait(self, jobs=None, timeout=-1):
734 """waits on one or more `jobs`, for up to `timeout` seconds.
735
736 Parameters
737 ----------
738
739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 ints are indices to self.history
741 strs are msg_ids
742 default: wait on all outstanding messages
743 timeout : float
744 a time in seconds, after which to give up.
745 default is -1, which means no timeout
746
747 Returns
748 -------
749
750 True : when all msg_ids are done
751 False : timeout reached, some msg_ids still outstanding
752 """
753 tic = time.time()
754 if jobs is None:
755 theids = self.outstanding
756 else:
757 if isinstance(jobs, (int, str, AsyncResult)):
758 jobs = [jobs]
759 theids = set()
760 for job in jobs:
761 if isinstance(job, int):
762 # index access
763 job = self.history[job]
764 elif isinstance(job, AsyncResult):
765 map(theids.add, job.msg_ids)
766 continue
767 theids.add(job)
768 if not theids.intersection(self.outstanding):
769 return True
770 self.spin()
771 while theids.intersection(self.outstanding):
772 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 break
774 time.sleep(1e-3)
775 self.spin()
776 return len(theids.intersection(self.outstanding)) == 0
777
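A minimal usage sketch of `wait` (an assumption for illustration: a controller and engines are already running, and `rc` is a connected `Client`):

    from IPython.parallel import Client

    rc = Client()                      # connect to a running controller (assumed)
    view = rc.load_balanced_view()
    ar = view.apply_async(lambda x: x * x, 4)

    rc.wait([ar], timeout=5)           # block up to 5s for this particular job
    rc.wait()                          # block until *all* outstanding messages are done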
778 #--------------------------------------------------------------------------
779 # Control methods
780 #--------------------------------------------------------------------------
781
782 @spin_first
783 def clear(self, targets=None, block=None):
784 """Clear the namespace in target(s)."""
785 block = self.block if block is None else block
786 targets = self._build_targets(targets)[0]
787 for t in targets:
788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 error = False
790 if block:
791 self._flush_ignored_control()
792 for i in range(len(targets)):
793 idents,msg = self.session.recv(self._control_socket,0)
794 if self.debug:
795 pprint(msg)
796 if msg['content']['status'] != 'ok':
797 error = self._unwrap_exception(msg['content'])
798 else:
799 self._ignored_control_replies += len(targets)
800 if error:
801 raise error
802
803
804 @spin_first
805 def abort(self, jobs=None, targets=None, block=None):
806 """Abort specific jobs from the execution queues of target(s).
807
808 This is a mechanism to prevent jobs that have already been submitted
809 from executing.
810
811 Parameters
812 ----------
813
814 jobs : msg_id, list of msg_ids, or AsyncResult
815 The jobs to be aborted
816
817
818 """
819 block = self.block if block is None else block
820 targets = self._build_targets(targets)[0]
821 msg_ids = []
822 if isinstance(jobs, (basestring,AsyncResult)):
823 jobs = [jobs]
824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 if bad_ids:
826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 for j in jobs:
828 if isinstance(j, AsyncResult):
829 msg_ids.extend(j.msg_ids)
830 else:
831 msg_ids.append(j)
832 content = dict(msg_ids=msg_ids)
833 for t in targets:
834 self.session.send(self._control_socket, 'abort_request',
835 content=content, ident=t)
836 error = False
837 if block:
838 self._flush_ignored_control()
839 for i in range(len(targets)):
840 idents,msg = self.session.recv(self._control_socket,0)
841 if self.debug:
842 pprint(msg)
843 if msg['content']['status'] != 'ok':
844 error = self._unwrap_exception(msg['content'])
845 else:
846 self._ignored_control_replies += len(targets)
847 if error:
848 raise error
849
850 @spin_first
851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 """Terminates one or more engine processes, optionally including the hub."""
853 block = self.block if block is None else block
854 if hub:
855 targets = 'all'
856 targets = self._build_targets(targets)[0]
857 for t in targets:
858 self.session.send(self._control_socket, 'shutdown_request',
859 content={'restart':restart},ident=t)
860 error = False
861 if block or hub:
862 self._flush_ignored_control()
863 for i in range(len(targets)):
864 idents,msg = self.session.recv(self._control_socket, 0)
865 if self.debug:
866 pprint(msg)
867 if msg['content']['status'] != 'ok':
868 error = self._unwrap_exception(msg['content'])
869 else:
870 self._ignored_control_replies += len(targets)
871
872 if hub:
873 time.sleep(0.25)
874 self.session.send(self._query_socket, 'shutdown_request')
875 idents,msg = self.session.recv(self._query_socket, 0)
876 if self.debug:
877 pprint(msg)
878 if msg['content']['status'] != 'ok':
879 error = self._unwrap_exception(msg['content'])
880
881 if error:
882 raise error
883
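A hedged sketch of the three control calls above (assuming the same connected `rc` and an in-flight AsyncResult `ar`):

    rc.clear(targets=[0, 1], block=True)   # wipe user namespaces on engines 0 and 1
    rc.abort(ar, block=False)              # keep a submitted-but-unstarted job from running
    rc.shutdown(hub=True)                  # stop all engines and the Hub itself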
884 #--------------------------------------------------------------------------
885 # Execution related methods
886 #--------------------------------------------------------------------------
887
888 def _maybe_raise(self, result):
889 """wrapper for maybe raising an exception if apply failed."""
890 if isinstance(result, error.RemoteError):
891 raise result
892
893 return result
894
895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 ident=None):
897 """construct and send an apply message via a socket.
898
899 This is the principal method with which all engine execution is performed by views.
900 """
901
902 assert not self._closed, "cannot use me anymore, I'm closed!"
903 # defaults:
904 args = args if args is not None else []
905 kwargs = kwargs if kwargs is not None else {}
906 subheader = subheader if subheader is not None else {}
907
908 # validate arguments
909 if not callable(f):
910 raise TypeError("f must be callable, not %s"%type(f))
911 if not isinstance(args, (tuple, list)):
912 raise TypeError("args must be tuple or list, not %s"%type(args))
913 if not isinstance(kwargs, dict):
914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 if not isinstance(subheader, dict):
916 raise TypeError("subheader must be dict, not %s"%type(subheader))
917
918 bufs = util.pack_apply_message(f,args,kwargs)
919
920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 subheader=subheader, track=track)
922
923 msg_id = msg['msg_id']
924 self.outstanding.add(msg_id)
925 if ident:
926 # possibly routed to a specific engine
927 if isinstance(ident, list):
928 ident = ident[-1]
929 if ident in self._engines.values():
930 # save for later, in case of engine death
931 self._outstanding_dict[ident].add(msg_id)
932 self.history.append(msg_id)
933 self.metadata[msg_id]['submitted'] = datetime.now()
934
935 return msg
936
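Views call `send_apply_message` internally; calling it directly is possible but unusual. An illustrative sketch only (the `_task_socket` attribute is private, so this merely mirrors what `LoadBalancedView` does under the hood, with the same assumed `rc`):

    def add(a, b):
        return a + b

    msg = rc.send_apply_message(rc._task_socket, add, args=(1, 2))
    msg_id = msg['msg_id']
    rc.wait([msg_id])                      # block until the reply has arrived
    print rc.results[msg_id]               # 3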
937 #--------------------------------------------------------------------------
938 # construct a View object
939 #--------------------------------------------------------------------------
940
941 def load_balanced_view(self, targets=None):
942 """construct a LoadBalancedView object.
943
944 If no arguments are specified, create a LoadBalancedView
945 using all engines.
946
947 Parameters
948 ----------
949
950 targets: list, slice, int, etc. [default: use all engines]
951 The subset of engines across which to load-balance
952 """
953 if targets is not None:
954 targets = self._build_targets(targets)[1]
955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956
957 def direct_view(self, targets='all'):
958 """construct a DirectView object.
959
960 If no targets are specified, create a DirectView
961 using all engines.
962
963 Parameters
964 ----------
965
966 targets: list, slice, int, etc. [default: use all engines]
967 The engines to use for the View
968 """
969 single = isinstance(targets, int)
970 targets = self._build_targets(targets)[1]
971 if single:
972 targets = targets[0]
973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974
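Constructing views, as a quick sketch (same assumed `rc`):

    dv = rc.direct_view()                  # DirectView on all engines ('all' is the default)
    dv0 = rc.direct_view(0)                # single-engine view
    lv = rc.load_balanced_view()           # load-balance across all engines
    lv02 = rc.load_balanced_view([0, 2])   # restrict balancing to engines 0 and 2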
975 #--------------------------------------------------------------------------
976 # Query methods
977 #--------------------------------------------------------------------------
978
979 @spin_first
980 def get_result(self, indices_or_msg_ids=None, block=None):
981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982
983 If the client already has the results, no request to the Hub will be made.
984
985 This is a convenient way to construct AsyncResult objects, which are wrappers
986 that include metadata about execution, and allow for awaiting results that
987 were not submitted by this Client.
988
989 It can also be a convenient way to retrieve the metadata associated with
990 blocking execution, since it always retrieves the metadata as well.
991
992 Examples
993 --------
994 ::
995
996 In [10]: r = client.apply()
997
998 Parameters
999 ----------
1000
1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 The indices or msg_ids of the results to be retrieved
1003
1004 block : bool
1005 Whether to wait for the result to be done
1006
1007 Returns
1008 -------
1009
1010 AsyncResult
1011 A single AsyncResult object will always be returned.
1012
1013 AsyncHubResult
1014 A subclass of AsyncResult that retrieves results from the Hub
1015
1016 """
1017 block = self.block if block is None else block
1018 if indices_or_msg_ids is None:
1019 indices_or_msg_ids = -1
1020
1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 indices_or_msg_ids = [indices_or_msg_ids]
1023
1024 theids = []
1025 for id in indices_or_msg_ids:
1026 if isinstance(id, int):
1027 id = self.history[id]
1028 if not isinstance(id, str):
1029 raise TypeError("indices must be str or int, not %r"%id)
1030 theids.append(id)
1031
1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034
1035 if remote_ids:
1036 ar = AsyncHubResult(self, msg_ids=theids)
1037 else:
1038 ar = AsyncResult(self, msg_ids=theids)
1039
1040 if block:
1041 ar.wait()
1042
1043 return ar
1044
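A sketch of both lookup styles (assuming prior submissions exist in `rc.history`):

    ar = rc.get_result(-1, block=True)     # most recent entry in rc.history
    print ar.get()

    msg_id = rc.history[0]
    ar2 = rc.get_result(msg_id)            # by msg_id; an AsyncHubResult if not local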
1045 @spin_first
1046 def result_status(self, msg_ids, status_only=True):
1047 """Check on the status of the result(s) of the apply request with `msg_ids`.
1048
1049 If status_only is False, then the actual results will be retrieved, else
1050 only the status of the results will be checked.
1051
1052 Parameters
1053 ----------
1054
1055 msg_ids : list of msg_ids
1056 if int:
1057 Passed as index to self.history for convenience.
1058 status_only : bool (default: True)
1059 if False:
1060 Retrieve the actual results of completed tasks.
1061
1062 Returns
1063 -------
1064
1065 results : dict
1066 There will always be the keys 'pending' and 'completed', which will
1067 be lists of msg_ids that are incomplete or complete. If `status_only`
1068 is False, then completed results will be keyed by their `msg_id`.
1069 """
1070 if not isinstance(msg_ids, (list,tuple)):
1071 msg_ids = [msg_ids]
1072
1073 theids = []
1074 for msg_id in msg_ids:
1075 if isinstance(msg_id, int):
1076 msg_id = self.history[msg_id]
1077 if not isinstance(msg_id, basestring):
1078 raise TypeError("msg_ids must be str, not %r"%msg_id)
1079 theids.append(msg_id)
1080
1081 completed = []
1082 local_results = {}
1083
1084 # comment this block out to temporarily disable local shortcut:
1085 for msg_id in list(theids): # iterate over a copy, since ids are removed as we go
1086 if msg_id in self.results:
1087 completed.append(msg_id)
1088 local_results[msg_id] = self.results[msg_id]
1089 theids.remove(msg_id)
1090
1091 if theids: # some not locally cached
1092 content = dict(msg_ids=theids, status_only=status_only)
1093 msg = self.session.send(self._query_socket, "result_request", content=content)
1094 zmq.select([self._query_socket], [], [])
1095 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1096 if self.debug:
1097 pprint(msg)
1098 content = msg['content']
1099 if content['status'] != 'ok':
1100 raise self._unwrap_exception(content)
1101 buffers = msg['buffers']
1102 else:
1103 content = dict(completed=[],pending=[])
1104
1105 content['completed'].extend(completed)
1106
1107 if status_only:
1108 return content
1109
1110 failures = []
1111 # load cached results into result:
1112 content.update(local_results)
1113 # update cache with results:
1114 for msg_id in sorted(theids):
1115 if msg_id in content['completed']:
1116 rec = content[msg_id]
1117 parent = rec['header']
1118 header = rec['result_header']
1119 rcontent = rec['result_content']
1120 iodict = rec['io']
1121 if isinstance(rcontent, str):
1122 rcontent = self.session.unpack(rcontent)
1123
1124 md = self.metadata[msg_id]
1125 md.update(self._extract_metadata(header, parent, rcontent))
1126 md.update(iodict)
1127
1128 if rcontent['status'] == 'ok':
1129 res,buffers = util.unserialize_object(buffers)
1130 else:
1131 print rcontent
1132 res = self._unwrap_exception(rcontent)
1133 failures.append(res)
1134
1135 self.results[msg_id] = res
1136 content[msg_id] = res
1137
1138 if len(theids) == 1 and failures:
1139 raise failures[0]
1140
1141 error.collect_exceptions(failures, "result_status")
1142 return content
1143
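For example (illustrative shapes only, assuming at least three entries in `rc.history`):

    status = rc.result_status(rc.history[-3:])
    print status['pending'], status['completed']

    full = rc.result_status(rc.history[-3:], status_only=False)
    for msg_id in full['completed']:
        print msg_id, full[msg_id]         # completed results keyed by msg_id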
1144 @spin_first
1145 def queue_status(self, targets='all', verbose=False):
1146 """Fetch the status of engine queues.
1147
1148 Parameters
1149 ----------
1150
1151 targets : int/str/list of ints/strs
1152 the engines whose states are to be queried.
1153 default : all
1154 verbose : bool
1155 Whether to return lengths only, or lists of ids for each element
1156 """
1157 engine_ids = self._build_targets(targets)[1]
1158 content = dict(targets=engine_ids, verbose=verbose)
1159 self.session.send(self._query_socket, "queue_request", content=content)
1160 idents,msg = self.session.recv(self._query_socket, 0)
1161 if self.debug:
1162 pprint(msg)
1163 content = msg['content']
1164 status = content.pop('status')
1165 if status != 'ok':
1166 raise self._unwrap_exception(content)
1167 content = util.rekey(content)
1168 if isinstance(targets, int):
1169 return content[targets]
1170 else:
1171 return content
1172
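For example (output shapes are illustrative):

    print rc.queue_status()                # per-engine dicts of queue/task/completed counts
    print rc.queue_status(0)               # a single engine's status dict
    print rc.queue_status('all', verbose=True)   # lists of msg_ids instead of lengths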
1173 @spin_first
1174 def purge_results(self, jobs=[], targets=[]):
1175 """Tell the Hub to forget results.
1176
1177 Individual results can be purged by msg_id, or the entire
1178 history of specific targets can be purged.
1179
1180 Parameters
1181 ----------
1182
1183 jobs : str or list of str or AsyncResult objects
1184 the msg_ids whose results should be forgotten.
1185 targets : int/str/list of ints/strs
1186 The targets, by uuid or int_id, whose entire history is to be purged.
1187 Use `targets='all'` to scrub everything from the Hub's memory.
1188
1189 default : None
1190 """
1191 if not targets and not jobs:
1192 raise ValueError("Must specify at least one of `targets` and `jobs`")
1193 if targets:
1194 targets = self._build_targets(targets)[1]
1195
1196 # construct msg_ids from jobs
1197 msg_ids = []
1198 if isinstance(jobs, (basestring,AsyncResult)):
1199 jobs = [jobs]
1200 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1201 if bad_ids:
1202 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1203 for j in jobs:
1204 if isinstance(j, AsyncResult):
1205 msg_ids.extend(j.msg_ids)
1206 else:
1207 msg_ids.append(j)
1208
1209 content = dict(targets=targets, msg_ids=msg_ids)
1210 self.session.send(self._query_socket, "purge_request", content=content)
1211 idents, msg = self.session.recv(self._query_socket, 0)
1212 if self.debug:
1213 pprint(msg)
1214 content = msg['content']
1215 if content['status'] != 'ok':
1216 raise self._unwrap_exception(content)
1217
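For example (assuming `ar` is an AsyncResult from an earlier submission):

    rc.purge_results(jobs=ar)              # forget specific results by AsyncResult or msg_id
    rc.purge_results(targets=[0, 1])       # drop the entire history of engines 0 and 1
    rc.purge_results(targets='all')        # scrub everything from the Hub's memory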
1218 @spin_first
1219 def hub_history(self):
1220 """Get the Hub's history
1221
1222 Just like the Client, the Hub has a history, which is a list of msg_ids.
1223 This will contain the history of all clients, and, depending on configuration,
1224 may contain history across multiple cluster sessions.
1225
1226 Any msg_id returned here is a valid argument to `get_result`.
1227
1228 Returns
1229 -------
1230
1231 msg_ids : list of strs
1232 list of all msg_ids, ordered by task submission time.
1233 """
1234
1235 self.session.send(self._query_socket, "history_request", content={})
1236 idents, msg = self.session.recv(self._query_socket, 0)
1237
1238 if self.debug:
1239 pprint(msg)
1240 content = msg['content']
1241 if content['status'] != 'ok':
1242 raise self._unwrap_exception(content)
1243 else:
1244 return content['history']
1245
1246 @spin_first
1247 def db_query(self, query, keys=None):
1248 """Query the Hub's TaskRecord database
1249
1250 This will return a list of task record dicts that match `query`
1251
1252 Parameters
1253 ----------
1254
1255 query : mongodb query dict
1256 The search dict. See mongodb query docs for details.
1257 keys : list of strs [optional]
1258 The subset of keys to be returned. The default is to fetch everything.
1259 'msg_id' will *always* be included.
1260 """
1261 content = dict(query=query, keys=keys)
1262 self.session.send(self._query_socket, "db_request", content=content)
1263 idents, msg = self.session.recv(self._query_socket, 0)
1264 if self.debug:
1265 pprint(msg)
1266 content = msg['content']
1267 if content['status'] != 'ok':
1268 raise self._unwrap_exception(content)
1269
1270 records = content['records']
1271 buffer_lens = content['buffer_lens']
1272 result_buffer_lens = content['result_buffer_lens']
1273 buffers = msg['buffers']
1274 has_bufs = buffer_lens is not None
1275 has_rbufs = result_buffer_lens is not None
1276 for i,rec in enumerate(records):
1277 # relink buffers
1278 if has_bufs:
1279 blen = buffer_lens[i]
1280 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1281 if has_rbufs:
1282 blen = result_buffer_lens[i]
1283 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1284 # turn timestamps back into times
1285 for key in 'submitted started completed resubmitted'.split():
1286 maybedate = rec.get(key, None)
1287 if maybedate and util.ISO8601_RE.match(maybedate):
1288 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1289
1290 return records
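A sketch of a typical query, using the same operator syntax exercised by the db backend tests (the one-hour cutoff is an arbitrary example):

    from datetime import datetime, timedelta

    cutoff = datetime.now() - timedelta(hours=1)
    recs = rc.db_query({'completed': {'$gte': cutoff}},
                       keys=['started', 'completed'])
    for rec in recs:
        print rec['msg_id'], rec['completed'] - rec['started']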
1291
1292 __all__ = [ 'Client' ]
@@ -1,155 +1,180 b''
1 """A Task logger that presents our DB interface,
1 """A Task logger that presents our DB interface,
2 but exists entirely in memory and implemented with dicts.
2 but exists entirely in memory and implemented with dicts.
3
3
4 TaskRecords are dicts of the form:
4 TaskRecords are dicts of the form:
5 {
5 {
6 'msg_id' : str(uuid),
6 'msg_id' : str(uuid),
7 'client_uuid' : str(uuid),
7 'client_uuid' : str(uuid),
8 'engine_uuid' : str(uuid) or None,
8 'engine_uuid' : str(uuid) or None,
9 'header' : dict(header),
9 'header' : dict(header),
10 'content': dict(content),
10 'content': dict(content),
11 'buffers': list(buffers),
11 'buffers': list(buffers),
12 'submitted': datetime,
12 'submitted': datetime,
13 'started': datetime or None,
13 'started': datetime or None,
14 'completed': datetime or None,
14 'completed': datetime or None,
15 'resubmitted': datetime or None,
15 'resubmitted': datetime or None,
16 'result_header' : dict(header) or None,
16 'result_header' : dict(header) or None,
17 'result_content' : dict(content) or None,
17 'result_content' : dict(content) or None,
18 'result_buffers' : list(buffers) or None,
18 'result_buffers' : list(buffers) or None,
19 }
19 }
20 With this info, many of the special categories of tasks can be defined by query:
20 With this info, many of the special categories of tasks can be defined by query:
21
21
22 pending: completed is None
22 pending: completed is None
23 client's outstanding: client_uuid = uuid && completed is None
23 client's outstanding: client_uuid = uuid && completed is None
24 MIA: arrived is None (and completed is None)
24 MIA: arrived is None (and completed is None)
25 etc.
25 etc.
26
26
27 EngineRecords are dicts of the form:
27 EngineRecords are dicts of the form:
28 {
28 {
29 'eid' : int(id),
29 'eid' : int(id),
30 'uuid': str(uuid)
30 'uuid': str(uuid)
31 }
31 }
32 This may be extended, but is currently.
32 This may be extended, but is currently.
33
33
34 We support a subset of mongodb operators:
34 We support a subset of mongodb operators:
35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 """
36 """
37 #-----------------------------------------------------------------------------
38 # Copyright (C) 2010 The IPython Development Team
39 #
40 # Distributed under the terms of the BSD License. The full license is in
41 # the file COPYING, distributed as part of this software.
42 #-----------------------------------------------------------------------------
43
44
45 from datetime import datetime
46
47 from IPython.config.configurable import Configurable
48
49 from IPython.utils.traitlets import Dict, CUnicode
50
51 filters = {
52 '$lt' : lambda a,b: a < b,
53 '$gt' : lambda a,b: a > b,
54 '$eq' : lambda a,b: a == b,
55 '$ne' : lambda a,b: a != b,
56 '$lte': lambda a,b: a <= b,
57 '$gte': lambda a,b: a >= b,
58 '$in' : lambda a,b: a in b,
59 '$nin': lambda a,b: a not in b,
60 '$all': lambda a,b: all([ a in bb for bb in b ]),
61 '$mod': lambda a,b: a%b[0] == b[1],
62 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
63 }
64
65
66 class CompositeFilter(object):
67 """Composite filter for matching multiple properties."""
68
69 def __init__(self, dikt):
70 self.tests = []
71 self.values = []
72 for key, value in dikt.iteritems():
73 self.tests.append(filters[key])
74 self.values.append(value)
75
76 def __call__(self, value):
77 for test,check in zip(self.tests, self.values):
78 if not test(value, check):
79 return False
80 return True
81
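To make the operator table concrete: each filter is called as test(record_value, query_value), and CompositeFilter ANDs several operators against one value. A minimal sketch:

    print filters['$gt'](5, 3)             # True: the record value 5 exceeds 3

    f = CompositeFilter({'$gt': 3, '$lt': 10})
    print f(5), f(12)                      # True False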
82 class BaseDB(Configurable):
83 """Empty Parent class so traitlets work on DB."""
84 # base configurable traits:
85 session = CUnicode("")
86
87 class DictDB(BaseDB):
88 """Basic in-memory dict-based object for saving Task Records.
89
90 This is the first object to present the DB interface
91 for logging tasks out of memory.
92
93 The interface is based on MongoDB, so adding a MongoDB
94 backend should be straightforward.
95 """
96
97 _records = Dict()
98
99 def _match_one(self, rec, tests):
100 """Check if a specific record matches tests."""
101 for key,test in tests.iteritems():
102 if not test(rec.get(key, None)):
103 return False
104 return True
105
106 def _match(self, check):
107 """Find all the matches for a check dict."""
108 matches = []
109 tests = {}
110 for k,v in check.iteritems():
111 if isinstance(v, dict):
112 tests[k] = CompositeFilter(v)
113 else:
114 tests[k] = lambda o, v=v: o==v # bind v now; a bare closure would capture only the last v
115
116 for rec in self._records.itervalues():
117 if self._match_one(rec, tests):
118 matches.append(rec)
119 return matches
120
121 def _extract_subdict(self, rec, keys):
122 """extract subdict of keys"""
123 d = {}
124 d['msg_id'] = rec['msg_id']
125 for key in keys:
126 d[key] = rec[key]
127 return d
128
129 def add_record(self, msg_id, rec):
130 """Add a new Task Record, by msg_id."""
131 if self._records.has_key(msg_id):
132 raise KeyError("Already have msg_id %r"%(msg_id))
133 self._records[msg_id] = rec
134
135 def get_record(self, msg_id):
136 """Get a specific Task Record, by msg_id."""
137 if not self._records.has_key(msg_id):
138 raise KeyError("No such msg_id %r"%(msg_id))
139 return self._records[msg_id]
140
141 def update_record(self, msg_id, rec):
142 """Update the data in an existing record."""
143 self._records[msg_id].update(rec)
144
145 def drop_matching_records(self, check):
146 """Remove all records matching a query dict from the DB."""
147 matches = self._match(check)
148 for m in matches:
149 del self._records[m['msg_id']] # matches are full records, so delete by msg_id
150
151 def drop_record(self, msg_id):
152 """Remove a record from the DB."""
153 del self._records[msg_id]
154
155
156 def find_records(self, check, keys=None):
157 """Find records matching a query dict, optionally extracting a subset of keys.
158
159 Returns a list of matching records.
160
161 Parameters
162 ----------
163
164 check: dict
165 mongodb-style query argument
166 keys: list of strs [optional]
167 if specified, the subset of keys to extract. msg_id will *always* be
168 included.
169 """
170 matches = self._match(check)
171 if keys:
172 return [ self._extract_subdict(rec, keys) for rec in matches ]
173 else:
174 return matches
175
176
177 def get_history(self):
178 """get all msg_ids, ordered by time submitted."""
179 msg_ids = self._records.keys()
180 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
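Putting the pieces together, a small self-contained sketch of the DictDB API (the record contents are trimmed to just the fields this query touches):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('abc', {'msg_id': 'abc', 'submitted': datetime.now(),
                          'completed': None, 'engine_uuid': None})
    pending = db.find_records({'completed': None})          # full records
    sub = db.find_records({'completed': None},
                          keys=['engine_uuid'])             # trimmed via _extract_subdict
    print sub[0]                           # {'msg_id': 'abc', 'engine_uuid': None}
    print db.get_history()                 # ['abc']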
@@ -1,1095 +1,1193 b''
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
5 """
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010 The IPython Development Team
8 #
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
12
13 #-----------------------------------------------------------------------------
14 # Imports
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
17
18 import sys
19 import time
20 from datetime import datetime
21
22 import zmq
23 from zmq.eventloop import ioloop
24 from zmq.eventloop.zmqstream import ZMQStream
25
26 # internal:
27 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29
30 from IPython.parallel import error, util
31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32
33 from .heartmonitor import HeartMonitor
34
35 #-----------------------------------------------------------------------------
36 # Code
37 #-----------------------------------------------------------------------------
38
39 def _passer(*args, **kwargs):
40 return
41
42 def _printer(*args, **kwargs):
43 print (args)
44 print (kwargs)
45
46 def empty_record():
47 """Return an empty dict with all record keys."""
48 return {
49 'msg_id' : None,
50 'header' : None,
51 'content': None,
52 'buffers': None,
53 'submitted': None,
54 'client_uuid' : None,
55 'engine_uuid' : None,
56 'started': None,
57 'completed': None,
58 'resubmitted': None,
59 'result_header' : None,
60 'result_content' : None,
61 'result_buffers' : None,
62 'queue' : None,
63 'pyin' : None,
64 'pyout': None,
65 'pyerr': None,
66 'stdout': '',
67 'stderr': '',
68 }
69
70 def init_record(msg):
71 """Initialize a TaskRecord based on a request."""
72 header = msg['header']
73 return {
74 'msg_id' : header['msg_id'],
75 'header' : header,
76 'content': msg['content'],
77 'buffers': msg['buffers'],
78 'submitted': datetime.strptime(header['date'], util.ISO8601),
79 'client_uuid' : None,
80 'engine_uuid' : None,
81 'started': None,
82 'completed': None,
83 'resubmitted': None,
84 'result_header' : None,
85 'result_content' : None,
86 'result_buffers' : None,
87 'queue' : None,
88 'pyin' : None,
89 'pyout': None,
90 'pyerr': None,
91 'stdout': '',
92 'stderr': '',
93 }
94
95
96 class EngineConnector(HasTraits):
97 """A simple object for accessing the various zmq connections of an object.
98 Attributes are:
99 id (int): engine ID
100 uuid (str): uuid (unused?)
101 queue (str): identity of queue's XREQ socket
102 registration (str): identity of registration XREQ socket
103 heartbeat (str): identity of heartbeat XREQ socket
104 """
105 id=Int(0)
106 queue=Str()
107 control=Str()
108 registration=Str()
109 heartbeat=Str()
110 pending=Set()
111
112 class HubFactory(RegistrationFactory):
113 """The Configurable for setting up a Hub."""
114
115 # name of a scheduler scheme
116 scheme = Str('leastload', config=True)
117
118 # port-pairs for monitoredqueues:
119 hb = Instance(list, config=True)
120 def _hb_default(self):
121 return util.select_random_ports(2)
122
123 mux = Instance(list, config=True)
124 def _mux_default(self):
125 return util.select_random_ports(2)
126
127 task = Instance(list, config=True)
128 def _task_default(self):
129 return util.select_random_ports(2)
130
131 control = Instance(list, config=True)
132 def _control_default(self):
133 return util.select_random_ports(2)
134
135 iopub = Instance(list, config=True)
136 def _iopub_default(self):
137 return util.select_random_ports(2)
138
139 # single ports:
140 mon_port = Instance(int, config=True)
141 def _mon_port_default(self):
142 return util.select_random_ports(1)[0]
143
144 notifier_port = Instance(int, config=True)
145 def _notifier_port_default(self):
146 return util.select_random_ports(1)[0]
147
148 ping = Int(1000, config=True) # ping frequency
149
150 engine_ip = CStr('127.0.0.1', config=True)
151 engine_transport = CStr('tcp', config=True)
152
153 client_ip = CStr('127.0.0.1', config=True)
154 client_transport = CStr('tcp', config=True)
155
156 monitor_ip = CStr('127.0.0.1', config=True)
157 monitor_transport = CStr('tcp', config=True)
158
159 monitor_url = CStr('')
160
161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
162
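Since db_class is configurable, swapping in another backend is a one-line config change; a hypothetical ipcontroller config sketch (the SQLite backend ships alongside DictDB):

    # in the cluster's controller config file:
    c.HubFactory.db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'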
164 # not configurable
163 # not configurable
165 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
164 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
166 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
165 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
167 subconstructors = List()
166 subconstructors = List()
168 _constructed = Bool(False)
167 _constructed = Bool(False)
169
168
170 def _ip_changed(self, name, old, new):
169 def _ip_changed(self, name, old, new):
171 self.engine_ip = new
170 self.engine_ip = new
172 self.client_ip = new
171 self.client_ip = new
173 self.monitor_ip = new
172 self.monitor_ip = new
174 self._update_monitor_url()
173 self._update_monitor_url()
175
174
176 def _update_monitor_url(self):
175 def _update_monitor_url(self):
177 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
176 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
178
177
179 def _transport_changed(self, name, old, new):
178 def _transport_changed(self, name, old, new):
180 self.engine_transport = new
179 self.engine_transport = new
181 self.client_transport = new
180 self.client_transport = new
182 self.monitor_transport = new
181 self.monitor_transport = new
183 self._update_monitor_url()
182 self._update_monitor_url()
184
183
185 def __init__(self, **kwargs):
184 def __init__(self, **kwargs):
186 super(HubFactory, self).__init__(**kwargs)
185 super(HubFactory, self).__init__(**kwargs)
187 self._update_monitor_url()
186 self._update_monitor_url()
188 # self.on_trait_change(self._sync_ips, 'ip')
187 # self.on_trait_change(self._sync_ips, 'ip')
189 # self.on_trait_change(self._sync_transports, 'transport')
188 # self.on_trait_change(self._sync_transports, 'transport')
190 self.subconstructors.append(self.construct_hub)
189 self.subconstructors.append(self.construct_hub)
191
190
192
191
193 def construct(self):
192 def construct(self):
193         assert not self._constructed, "already constructed!"
194
195         for subc in self.subconstructors:
196             subc()
197
198         self._constructed = True
199
200
201     def start(self):
202         assert self._constructed, "must be constructed by self.construct() first!"
203         self.heartmonitor.start()
204         self.log.info("Heartmonitor started")
205
206     def construct_hub(self):
207         """construct"""
208         client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
209         engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
210
211         ctx = self.context
212         loop = self.loop
213
214         # Registrar socket
215         q = ZMQStream(ctx.socket(zmq.XREP), loop)
216         q.bind(client_iface % self.regport)
217         self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
218         if self.client_ip != self.engine_ip:
219             q.bind(engine_iface % self.regport)
220             self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
221
222         ### Engine connections ###
223
224         # heartbeat
225         hpub = ctx.socket(zmq.PUB)
226         hpub.bind(engine_iface % self.hb[0])
227         hrep = ctx.socket(zmq.XREP)
228         hrep.bind(engine_iface % self.hb[1])
229         self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
230                                 period=self.ping, logname=self.log.name)
231
232         ### Client connections ###
233         # Notifier socket
234         n = ZMQStream(ctx.socket(zmq.PUB), loop)
235         n.bind(client_iface%self.notifier_port)
236
237         ### build and launch the queues ###
238
239         # monitor socket
240         sub = ctx.socket(zmq.SUB)
241         sub.setsockopt(zmq.SUBSCRIBE, "")
242         sub.bind(self.monitor_url)
243         sub.bind('inproc://monitor')
244         sub = ZMQStream(sub, loop)
245
246         # connect the db
247         self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
248         # cdir = self.config.Global.cluster_dir
249         self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
250         time.sleep(.25)
251
252         # build connection dicts
253         self.engine_info = {
254             'control' : engine_iface%self.control[1],
255             'mux': engine_iface%self.mux[1],
256             'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
257             'task' : engine_iface%self.task[1],
258             'iopub' : engine_iface%self.iopub[1],
259             # 'monitor' : engine_iface%self.mon_port,
260             }
261
262         self.client_info = {
263             'control' : client_iface%self.control[0],
264             'mux': client_iface%self.mux[0],
265             'task' : (self.scheme, client_iface%self.task[0]),
266             'iopub' : client_iface%self.iopub[0],
267             'notification': client_iface%self.notifier_port
268             }
269         self.log.debug("Hub engine addrs: %s"%self.engine_info)
270         self.log.debug("Hub client addrs: %s"%self.client_info)
271         self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
272                 query=q, notifier=n, db=self.db,
273                 engine_info=self.engine_info, client_info=self.client_info,
274                 logname=self.log.name)
275
276
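construct_hub picks the storage backend purely by dotted import path: import_item(self.db_class) resolves the string to a class, which is then instantiated with the session id and config. A minimal sketch of that resolution step, assuming the backend paths used in this changeset's test module:

    # sketch: resolving a DB backend from a dotted-path string
    from IPython.utils.importstring import import_item

    db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'   # or dictdb.DictDB
    DBClass = import_item(db_class)      # string -> class object
    # construct_hub then does: DBClass(session=self.session.session, config=self.config)
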
277 class Hub(LoggingFactory):
278     """The IPython Controller Hub with 0MQ connections
279
280     Parameters
281     ==========
282     loop: zmq IOLoop instance
283     session: StreamSession object
284     <removed> context: zmq context for creating new connections (?)
285     queue: ZMQStream for monitoring the command queue (SUB)
286     query: ZMQStream for engine registration and client query requests (XREP)
287     heartbeat: HeartMonitor object checking the pulse of the engines
288     notifier: ZMQStream for broadcasting engine registration changes (PUB)
289     db: connection to db for out of memory logging of commands
290                 NotImplemented
291     engine_info: dict of zmq connection information for engines to connect
292                 to the queues.
293     client_info: dict of zmq connection information for clients to connect
294                 to the queues.
295     """
296     # internal data structures:
297     ids=Set() # engine IDs
298     keytable=Dict()
299     by_ident=Dict()
300     engines=Dict()
301     clients=Dict()
302     hearts=Dict()
303     pending=Set()
304     queues=Dict() # pending msg_ids keyed by engine_id
305     tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
306     completed=Dict() # completed msg_ids keyed by engine_id
307     all_completed=Set() # set of all completed msg_ids
308     dead_engines=Set() # set of uuids of engines that have unregistered
309     unassigned=Set() # set of task msg_ids not yet assigned a destination
310     incoming_registrations=Dict()
311     registration_timeout=Int()
312     _idcounter=Int(0)
313
314     # objects from constructor:
315     loop=Instance(ioloop.IOLoop)
316     query=Instance(ZMQStream)
317     monitor=Instance(ZMQStream)
318     heartmonitor=Instance(HeartMonitor)
319     notifier=Instance(ZMQStream)
320     db=Instance(object)
321     client_info=Dict()
322     engine_info=Dict()
323
324
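The bookkeeping above maintains one invariant worth stating: every in-flight msg_id sits in pending, and is additionally tracked either in queues[eid] (MUX traffic), in tasks[eid] (a routed task), or in unassigned (a task the scheduler has not yet placed). A toy walk-through of one task's lifecycle under that invariant (hypothetical ids, not Hub API):

    pending, unassigned, all_completed = set(), set(), set()
    queues, tasks, completed = {0: []}, {0: []}, {0: []}

    msg_id = 'abc'                                       # hypothetical task
    pending.add(msg_id); unassigned.add(msg_id)          # save_task_request
    unassigned.discard(msg_id); tasks[0].append(msg_id)  # save_task_destination
    pending.discard(msg_id); tasks[0].remove(msg_id)     # save_task_result:
    completed[0].append(msg_id); all_completed.add(msg_id)
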
325     def __init__(self, **kwargs):
326         """
327         # universal:
328         loop: IOLoop for creating future connections
329         session: streamsession for sending serialized data
330         # engine:
331         queue: ZMQStream for monitoring queue messages
332         query: ZMQStream for engine+client registration and client requests
333         heartbeat: HeartMonitor object for tracking engines
334         # extra:
335         db: DB backend object, for storing task records
336         engine_info: zmq address/protocol dict for engine connections
337         client_info: zmq address/protocol dict for client connections
338         """
339
340         super(Hub, self).__init__(**kwargs)
341         self.registration_timeout = max(5000, 2*self.heartmonitor.period)
342
343         # validate connection dicts:
344         for k,v in self.client_info.iteritems():
345             if k == 'task':
346                 util.validate_url_container(v[1])
347             else:
348                 util.validate_url_container(v)
349         # util.validate_url_container(self.client_info)
350         util.validate_url_container(self.engine_info)
351
352         # register our callbacks
353         self.query.on_recv(self.dispatch_query)
354         self.monitor.on_recv(self.dispatch_monitor_traffic)
355
356         self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
357         self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
358
359         self.monitor_handlers = { 'in' : self.save_queue_request,
360                                 'out': self.save_queue_result,
361                                 'intask': self.save_task_request,
362                                 'outtask': self.save_task_result,
363                                 'tracktask': self.save_task_destination,
364                                 'incontrol': _passer,
365                                 'outcontrol': _passer,
366                                 'iopub': self.save_iopub_message,
367         }
368
369         self.query_handlers = {'queue_request': self.queue_status,
370                                 'result_request': self.get_results,
371                                 'history_request': self.get_history,
372                                 'db_request': self.db_query,
373                                 'purge_request': self.purge_results,
374                                 'load_request': self.check_load,
375                                 'resubmit_request': self.resubmit_task,
376                                 'shutdown_request': self.shutdown_request,
377                                 'registration_request' : self.register_engine,
378                                 'unregistration_request' : self.unregister_engine,
379                                 'connection_request': self.connection_request,
380         }
381
382         self.log.info("hub::created hub")
383
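The query_handlers table is effectively the Hub's client protocol: a client sends a message whose msg_type matches one of these keys on the query (XREP) socket and receives a corresponding *_reply, or hub_error on failure. A minimal client-side sketch, assuming a connected query ZMQStream and a StreamSession with the send signature used throughout this file:

    # ask for the full history (history_request is new in this changeset)
    session.send(query, 'history_request', content={})
    # raw record query against the new db_query handler:
    session.send(query, 'db_request',
                 content=dict(query={'completed': {'$ne': None}}, keys=['msg_id']))
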
384     @property
385     def _next_id(self):
386         """generate a new ID.
387
388         No longer reuse old ids, just count from 0."""
389         newid = self._idcounter
390         self._idcounter += 1
391         return newid
392         # newid = 0
393         # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
394         # # print newid, self.ids, self.incoming_registrations
395         # while newid in self.ids or newid in incoming:
396         #     newid += 1
397         # return newid
398
399     #-----------------------------------------------------------------------------
400     # message validation
401     #-----------------------------------------------------------------------------
402
403     def _validate_targets(self, targets):
404         """turn any valid targets argument into a list of integer ids"""
405         if targets is None:
406             # default to all
407             targets = self.ids
408
409         if isinstance(targets, (int,str,unicode)):
410             # only one target specified
411             targets = [targets]
412         _targets = []
413         for t in targets:
414             # map raw identities to ids
415             if isinstance(t, (str,unicode)):
416                 t = self.by_ident.get(t, t)
417             _targets.append(t)
418         targets = _targets
419         bad_targets = [ t for t in targets if t not in self.ids ]
420         if bad_targets:
421             raise IndexError("No Such Engine: %r"%bad_targets)
422         if not targets:
423             raise IndexError("No Engines Registered")
424         return targets
425
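For reference, _validate_targets normalizes every accepted targets form to a list of integer engine ids, mapping raw queue identities through by_ident, and raises IndexError otherwise. Hypothetical values, assuming engines 0-3 are registered:

    hub._validate_targets(None)           # -> [0, 1, 2, 3]   (default: all ids)
    hub._validate_targets(2)              # -> [2]
    hub._validate_targets(['<uuid>', 3])  # identity string mapped via by_ident
    hub._validate_targets([7])            # IndexError: No Such Engine: [7]
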
426     #-----------------------------------------------------------------------------
427     # dispatch methods (1 per stream)
428     #-----------------------------------------------------------------------------
429
430     # def dispatch_registration_request(self, msg):
431     #     """"""
432     #     self.log.debug("registration::dispatch_register_request(%s)"%msg)
433     #     idents,msg = self.session.feed_identities(msg)
434     #     if not idents:
435     #         self.log.error("Bad Query Message: %s"%msg, exc_info=True)
436     #         return
437     #     try:
438     #         msg = self.session.unpack_message(msg,content=True)
439     #     except:
440     #         self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
441     #         return
442     #
443     #     msg_type = msg['msg_type']
444     #     content = msg['content']
445     #
446     #     handler = self.query_handlers.get(msg_type, None)
447     #     if handler is None:
448     #         self.log.error("registration::got bad registration message: %s"%msg)
449     #     else:
450     #         handler(idents, msg)
451
452     def dispatch_monitor_traffic(self, msg):
453         """all ME and Task queue messages come through here, as well as
454         IOPub traffic."""
455         self.log.debug("monitor traffic: %s"%msg[:2])
456         switch = msg[0]
457         idents, msg = self.session.feed_identities(msg[1:])
458         if not idents:
459             self.log.error("Bad Monitor Message: %s"%msg)
460             return
461         handler = self.monitor_handlers.get(switch, None)
462         if handler is not None:
463             handler(idents, msg)
464         else:
465             self.log.error("Invalid monitor topic: %s"%switch)
466
467
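dispatch_monitor_traffic relies on the frame layout put on the monitor socket: the topic comes first (one of the monitor_handlers keys) and the remainder is the routed message. Schematically:

    # frame 0:     topic, e.g. 'in', 'outtask', 'iopub'  (selects the handler)
    # frames 1..n: [ident, ..., delimiter, serialized message]
    switch = msg[0]
    idents, msg = self.session.feed_identities(msg[1:])
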
468     def dispatch_query(self, msg):
469         """Route registration requests and queries from clients."""
470         idents, msg = self.session.feed_identities(msg)
471         if not idents:
472             self.log.error("Bad Query Message: %s"%msg)
473             return
474         client_id = idents[0]
475         try:
476             msg = self.session.unpack_message(msg, content=True)
477         except:
478             content = error.wrap_exception()
479             self.log.error("Bad Query Message: %s"%msg, exc_info=True)
480             self.session.send(self.query, "hub_error", ident=client_id,
481                     content=content)
482             return
483
484         # print client_id, header, parent, content
485         # switch on message type:
486         msg_type = msg['msg_type']
487         self.log.info("client::client %s requested %s"%(client_id, msg_type))
488         handler = self.query_handlers.get(msg_type, None)
489         try:
490             assert handler is not None, "Bad Message Type: %s"%msg_type
491         except:
492             content = error.wrap_exception()
493             self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
494             self.session.send(self.query, "hub_error", ident=client_id,
495                     content=content)
496             return
497         else:
498             handler(idents, msg)
499
500     def dispatch_db(self, msg):
501         """"""
502         raise NotImplementedError
503
504     #---------------------------------------------------------------------------
505     # handler methods (1 per event)
506     #---------------------------------------------------------------------------
507
508     #----------------------- Heartbeat --------------------------------------
509
510     def handle_new_heart(self, heart):
511         """handler to attach to heartbeater.
512         Called when a new heart starts to beat.
513         Triggers completion of registration."""
514         self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
515         if heart not in self.incoming_registrations:
516             self.log.info("heartbeat::ignoring new heart: %r"%heart)
517         else:
518             self.finish_registration(heart)
519
520
521     def handle_heart_failure(self, heart):
522         """handler to attach to heartbeater.
523         called when a previously registered heart fails to respond to beat request.
524         triggers unregistration"""
525         self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
526         eid = self.hearts.get(heart, None)
527         if eid is None:
528             self.log.info("heartbeat::ignoring heart failure %r"%heart)
529         else:
530             # look up the queue id only after confirming the heart is known
531             queue = self.engines[eid].queue
532             self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
533     #----------------------- MUX Queue Traffic ------------------------------
534
535     def save_queue_request(self, idents, msg):
536         if len(idents) < 2:
537             self.log.error("invalid identity prefix: %s"%idents)
538             return
539         queue_id, client_id = idents[:2]
540         try:
541             msg = self.session.unpack_message(msg, content=False)
542         except:
543             self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
544             return
545
546         eid = self.by_ident.get(queue_id, None)
547         if eid is None:
548             self.log.error("queue::target %r not registered"%queue_id)
549             self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
550             return
551
552         header = msg['header']
553         msg_id = header['msg_id']
554         record = init_record(msg)
555         record['engine_uuid'] = queue_id
556         record['client_uuid'] = client_id
557         record['queue'] = 'mux'
558
559         try:
560             # it's possible iopub arrived first:
561             existing = self.db.get_record(msg_id)
562             for key,evalue in existing.iteritems():
563                 rvalue = record[key]
564                 if evalue and rvalue and evalue != rvalue:
565                     self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
566                 elif evalue and not rvalue:
567                     record[key] = evalue
568             self.db.update_record(msg_id, record)
569         except KeyError:
570             self.db.add_record(msg_id, record)
571
572         self.pending.add(msg_id)
573         self.queues[eid].append(msg_id)
574
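The merge loop in save_queue_request (and its twin in save_task_request below) resolves a race where iopub output for a message reaches the db before the request itself: existing non-empty values win, empty ones are filled in, and genuine disagreements are logged. Factored out as a standalone sketch (merge_record is not a name in this changeset):

    def merge_record(existing, record, log, msg_id):
        """Merge a pre-existing db record into a fresh one (sketch)."""
        for key, evalue in existing.iteritems():
            rvalue = record[key]
            if evalue and rvalue and evalue != rvalue:
                log.error("conflicting initial state for record: %s:%s <> %s" % (
                        msg_id, rvalue, evalue))
            elif evalue and not rvalue:
                record[key] = evalue
        return record
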
575     def save_queue_result(self, idents, msg):
576         if len(idents) < 2:
577             self.log.error("invalid identity prefix: %s"%idents)
578             return
579
580         client_id, queue_id = idents[:2]
581         try:
582             msg = self.session.unpack_message(msg, content=False)
583         except:
584             self.log.error("queue::engine %r sent invalid message to %r: %s"%(
585                     queue_id,client_id, msg), exc_info=True)
586             return
587
588         eid = self.by_ident.get(queue_id, None)
589         if eid is None:
590             self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
591             # self.log.debug("queue:: %s"%msg[2:])
592             return
593
594         parent = msg['parent_header']
595         if not parent:
596             return
597         msg_id = parent['msg_id']
598         if msg_id in self.pending:
599             self.pending.remove(msg_id)
600             self.all_completed.add(msg_id)
601             self.queues[eid].remove(msg_id)
602             self.completed[eid].append(msg_id)
603         elif msg_id not in self.all_completed:
604             # it could be a result from a dead engine that died before delivering the
605             # result
606             self.log.warn("queue:: unknown msg finished %s"%msg_id)
607             return
608         # update record anyway, because the unregistration could have been premature
609         rheader = msg['header']
610         completed = datetime.strptime(rheader['date'], util.ISO8601)
611         started = rheader.get('started', None)
612         if started is not None:
613             started = datetime.strptime(started, util.ISO8601)
614         result = {
615             'result_header' : rheader,
616             'result_content': msg['content'],
617             'started' : started,
618             'completed' : completed
619         }
620
621         result['result_buffers'] = msg['buffers']
622         try:
623             self.db.update_record(msg_id, result)
624         except Exception:
625             self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
626
627
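The try/except Exception around db.update_record above is the recurring pattern of this changeset: a misbehaving backend costs one error log entry instead of an unhandled exception inside an IOLoop callback, which would take the whole Hub down. Every db call made from a message handler below is wrapped the same way:

    try:
        self.db.update_record(msg_id, result)
    except Exception:
        # never let a backend error escape into the event loop
        self.log.error("DB Error updating record %r" % msg_id, exc_info=True)
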
628     #--------------------- Task Queue Traffic ------------------------------
629
630     def save_task_request(self, idents, msg):
631         """Save the submission of a task."""
632         client_id = idents[0]
633
634         try:
635             msg = self.session.unpack_message(msg, content=False)
636         except:
637             self.log.error("task::client %r sent invalid task message: %s"%(
638                     client_id, msg), exc_info=True)
639             return
640         record = init_record(msg)
641
642         record['client_uuid'] = client_id
643         record['queue'] = 'task'
644         header = msg['header']
645         msg_id = header['msg_id']
646         self.pending.add(msg_id)
647         self.unassigned.add(msg_id)
648         try:
649             # it's possible iopub arrived first:
650             existing = self.db.get_record(msg_id)
651             for key,evalue in existing.iteritems():
652                 rvalue = record[key]
653                 if evalue and rvalue and evalue != rvalue:
654                     self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
655                 elif evalue and not rvalue:
656                     record[key] = evalue
657             self.db.update_record(msg_id, record)
658         except KeyError:
659             self.db.add_record(msg_id, record)
660         except Exception:
661             self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
662
663     def save_task_result(self, idents, msg):
664         """save the result of a completed task."""
665         client_id = idents[0]
666         try:
667             msg = self.session.unpack_message(msg, content=False)
668         except:
669             self.log.error("task::invalid task result message sent to %r: %s"%(
670                     client_id, msg), exc_info=True)
671             # don't re-raise: a malformed result should not kill the handler
672             return
673
674         parent = msg['parent_header']
675         if not parent:
676             # print msg
677             self.log.warn("Task %r had no parent!"%msg)
678             return
679         msg_id = parent['msg_id']
680         if msg_id in self.unassigned:
681             self.unassigned.remove(msg_id)
682
683         header = msg['header']
684         engine_uuid = header.get('engine', None)
685         eid = self.by_ident.get(engine_uuid, None)
686
687         if msg_id in self.pending:
688             self.pending.remove(msg_id)
689             self.all_completed.add(msg_id)
690             if eid is not None:
691                 self.completed[eid].append(msg_id)
692                 if msg_id in self.tasks[eid]:
693                     self.tasks[eid].remove(msg_id)
694             completed = datetime.strptime(header['date'], util.ISO8601)
695             started = header.get('started', None)
696             if started is not None:
697                 started = datetime.strptime(started, util.ISO8601)
698             result = {
699                 'result_header' : header,
700                 'result_content': msg['content'],
701                 'started' : started,
702                 'completed' : completed,
703                 'engine_uuid': engine_uuid
704             }
705
706             result['result_buffers'] = msg['buffers']
707             try:
708                 self.db.update_record(msg_id, result)
709             except Exception:
710                 self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)
711
712         else:
713             self.log.debug("task::unknown task %s finished"%msg_id)
714
715     def save_task_destination(self, idents, msg):
716         try:
717             msg = self.session.unpack_message(msg, content=True)
718         except:
719             self.log.error("task::invalid task tracking message", exc_info=True)
720             return
721         content = msg['content']
722         # print (content)
723         msg_id = content['msg_id']
724         engine_uuid = content['engine_id']
725         eid = self.by_ident[engine_uuid]
726
727         self.log.info("task::task %s arrived on %s"%(msg_id, eid))
728         if msg_id in self.unassigned:
729             self.unassigned.remove(msg_id)
730         # else:
731         #     self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
732
733         self.tasks[eid].append(msg_id)
734         # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
735         try:
736             self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
737         except Exception:
738             self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
739
740
741     def mia_task_request(self, idents, msg):
742         raise NotImplementedError
743         client_id = idents[0]
744         # content = dict(mia=self.mia,status='ok')
745         # self.session.send('mia_reply', content=content, idents=client_id)
746
747
748     #--------------------- IOPub Traffic ------------------------------
749
750     def save_iopub_message(self, topics, msg):
751         """save an iopub message into the db"""
752         # print (topics)
753         try:
754             msg = self.session.unpack_message(msg, content=True)
755         except:
756             self.log.error("iopub::invalid IOPub message", exc_info=True)
757             return
758
759         parent = msg['parent_header']
760         if not parent:
761             self.log.error("iopub::invalid IOPub message: %s"%msg)
762             return
763         msg_id = parent['msg_id']
764         msg_type = msg['msg_type']
765         content = msg['content']
766
767         # ensure msg_id is in db
768         try:
769             rec = self.db.get_record(msg_id)
770         except KeyError:
771             rec = empty_record()
772             rec['msg_id'] = msg_id
773             self.db.add_record(msg_id, rec)
774         # stream
775         d = {}
776         if msg_type == 'stream':
777             name = content['name']
778             s = rec[name] or ''
779             d[name] = s + content['data']
780
781         elif msg_type == 'pyerr':
782             d['pyerr'] = content
783         elif msg_type == 'pyin':
784             d['pyin'] = content['code']
785         else:
786             d[msg_type] = content.get('data', '')
787
788         try:
789             self.db.update_record(msg_id, d)
790         except Exception:
791             self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
792
793
794
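Note how stream output accumulates: for each iopub 'stream' message the handler concatenates the new data onto whatever the record already holds, so a task's stdout ends up as a single string in the db. Toy illustration of that fold:

    rec = {'stdout': ''}               # record field as stored in the db
    for data in ['a', 'b', 'c']:       # three successive 'stream' messages
        s = rec['stdout'] or ''
        rec['stdout'] = s + data
    assert rec['stdout'] == 'abc'
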
795     #-------------------------------------------------------------------------
796     # Registration requests
797     #-------------------------------------------------------------------------
798
799     def connection_request(self, client_id, msg):
800         """Reply with connection addresses for clients."""
801         self.log.info("client::client %s connected"%client_id)
802         content = dict(status='ok')
803         content.update(self.client_info)
804         jsonable = {}
805         for k,v in self.keytable.iteritems():
806             if v not in self.dead_engines:
807                 jsonable[str(k)] = v
808         content['engines'] = jsonable
809         self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
810
811     def register_engine(self, reg, msg):
812         """Register a new engine."""
813         content = msg['content']
814         try:
815             queue = content['queue']
816         except KeyError:
817             self.log.error("registration::queue not specified", exc_info=True)
818             return
819         heart = content.get('heartbeat', None)
820         # register a new engine, and create the socket(s) necessary
821         eid = self._next_id
822         # print (eid, queue, reg, heart)
823
824         self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
825
826         content = dict(id=eid,status='ok')
827         content.update(self.engine_info)
828         # check if requesting available IDs:
829         if queue in self.by_ident:
830             try:
831                 raise KeyError("queue_id %r in use"%queue)
832             except:
833                 content = error.wrap_exception()
834                 self.log.error("queue_id %r in use"%queue, exc_info=True)
835         elif heart in self.hearts: # need to check unique hearts?
836             try:
837                 raise KeyError("heart_id %r in use"%heart)
838             except:
839                 self.log.error("heart_id %r in use"%heart, exc_info=True)
840                 content = error.wrap_exception()
841         else:
842             for h, pack in self.incoming_registrations.iteritems():
843                 if heart == h:
844                     try:
845                         raise KeyError("heart_id %r in use"%heart)
846                     except:
847                         self.log.error("heart_id %r in use"%heart, exc_info=True)
848                         content = error.wrap_exception()
849                     break
850                 elif queue == pack[1]:
851                     try:
852                         raise KeyError("queue_id %r in use"%queue)
853                     except:
854                         self.log.error("queue_id %r in use"%queue, exc_info=True)
855                         content = error.wrap_exception()
856                     break
857
858         msg = self.session.send(self.query, "registration_reply",
859                 content=content,
860                 ident=reg)
861
862         if content['status'] == 'ok':
863             if heart in self.heartmonitor.hearts:
864                 # already beating
865                 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
866                 self.finish_registration(heart)
867             else:
868                 purge = lambda : self._purge_stalled_registration(heart)
869                 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
870                 dc.start()
871                 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
872         else:
873             self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
874         return eid
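Registration is two-phase: register_engine replies immediately, but the engine only becomes usable once its heart beats and finish_registration runs; _purge_stalled_registration is the timeout arm. The incoming_registrations entry carries everything both arms need:

    # incoming_registrations[heart] = (eid, queue_id, registration_ident, dc)
    #   dc is the DelayedCallback for the purge timeout, or None if the
    #   heart was already beating when the registration request arrived.
    # heart beats   -> handle_new_heart -> finish_registration(heart)
    # timeout fires -> _purge_stalled_registration(heart) drops the entry
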
875
876     def unregister_engine(self, ident, msg):
877         """Unregister an engine that explicitly requested to leave."""
878         try:
879             eid = msg['content']['id']
880         except:
881             self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
882             return
883         self.log.info("registration::unregister_engine(%s)"%eid)
884         # print (eid)
885         uuid = self.keytable[eid]
886         content=dict(id=eid, queue=uuid)
887         self.dead_engines.add(uuid)
888         # self.ids.remove(eid)
889         # uuid = self.keytable.pop(eid)
890         #
891         # ec = self.engines.pop(eid)
892         # self.hearts.pop(ec.heartbeat)
893         # self.by_ident.pop(ec.queue)
894         # self.completed.pop(eid)
895         handleit = lambda : self._handle_stranded_msgs(eid, uuid)
896         dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
897         dc.start()
898         ############## TODO: HANDLE IT ################
899
900         if self.notifier:
901             self.session.send(self.notifier, "unregistration_notification", content=content)
902
903     def _handle_stranded_msgs(self, eid, uuid):
904         """Handle messages known to be on an engine when the engine unregisters.
905
906         It is possible that this will fire prematurely - that is, an engine will
907         go down after completing a result, and the client will be notified
908         that the result failed and later receive the actual result.
909         """
910
911         outstanding = self.queues[eid]
912
913         for msg_id in outstanding:
914             self.pending.remove(msg_id)
915             self.all_completed.add(msg_id)
916             try:
917                 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
918             except:
919                 content = error.wrap_exception()
920             # build a fake header:
921             header = {}
922             header['engine'] = uuid
923             header['date'] = datetime.now()
924             rec = dict(result_content=content, result_header=header, result_buffers=[])
925             rec['completed'] = header['date']
926             rec['engine_uuid'] = uuid
927             try:
928                 self.db.update_record(msg_id, rec)
929             except Exception:
930                 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
931
932
933     def finish_registration(self, heart):
934         """Second half of engine registration, called after our HeartMonitor
935         has received a beat from the Engine's Heart."""
936         try:
937             (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
938         except KeyError:
939             self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
940             return
941         self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
942         if purge is not None:
943             purge.stop()
944         control = queue
945         self.ids.add(eid)
946         self.keytable[eid] = queue
947         self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
948                 control=control, heartbeat=heart)
949         self.by_ident[queue] = eid
950         self.queues[eid] = list()
951         self.tasks[eid] = list()
952         self.completed[eid] = list()
953         self.hearts[heart] = eid
954         content = dict(id=eid, queue=self.engines[eid].queue)
955         if self.notifier:
956             self.session.send(self.notifier, "registration_notification", content=content)
957         self.log.info("engine::Engine Connected: %i"%eid)
958
959     def _purge_stalled_registration(self, heart):
960         if heart in self.incoming_registrations:
961             eid = self.incoming_registrations.pop(heart)[0]
962             self.log.info("registration::purging stalled registration: %i"%eid)
963         else:
964             pass
965
966     #-------------------------------------------------------------------------
967     # Client Requests
968     #-------------------------------------------------------------------------
969
970     def shutdown_request(self, client_id, msg):
971         """handle shutdown request."""
972         self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
973         # also notify other clients of shutdown
974         self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
975         dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
976         dc.start()
977
978     def _shutdown(self):
979         self.log.info("hub::hub shutting down.")
980         time.sleep(0.1)
981         sys.exit(0)
982
983
984     def check_load(self, client_id, msg):
985         content = msg['content']
986         try:
987             targets = content['targets']
988             targets = self._validate_targets(targets)
989         except:
990             content = error.wrap_exception()
991             self.session.send(self.query, "hub_error",
992                     content=content, ident=client_id)
993             return
994
995         content = dict(status='ok')
996         # loads = {}
997         for t in targets:
998             content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
999         self.session.send(self.query, "load_reply", content=content, ident=client_id)
1000
1001
1002     def queue_status(self, client_id, msg):
1003         """Return the Queue status of one or more targets.
1004         if verbose: return the msg_ids
1005         else: return len of each type.
1006         keys: queue (pending MUX jobs)
1007             tasks (pending Task jobs)
1008             completed (finished jobs from both queues)"""
1009         content = msg['content']
1010         targets = content['targets']
1011         try:
1012             targets = self._validate_targets(targets)
1013         except:
1014             content = error.wrap_exception()
1015             self.session.send(self.query, "hub_error",
1016                     content=content, ident=client_id)
1017             return
1018         verbose = content.get('verbose', False)
1019         content = dict(status='ok')
1020         for t in targets:
1021             queue = self.queues[t]
1022             completed = self.completed[t]
1023             tasks = self.tasks[t]
1024             if not verbose:
1025                 queue = len(queue)
1026                 completed = len(completed)
1027                 tasks = len(tasks)
1028             content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1029         content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1030
1031         self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1032
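For reference, a non-verbose queue_reply content has one entry per requested engine id plus the global unassigned count (hypothetical numbers):

    {'status': 'ok',
     '0': {'queue': 2, 'completed': 10, 'tasks': 1},
     '1': {'queue': 0, 'completed': 12, 'tasks': 3},
     'unassigned': 1}
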
1033     def purge_results(self, client_id, msg):
1034         """Purge results from memory. This method is more valuable before we move
1035         to a DB based message storage mechanism."""
1036         content = msg['content']
1037         msg_ids = content.get('msg_ids', [])
1038         reply = dict(status='ok')
1039         if msg_ids == 'all':
1040             try:
1041                 self.db.drop_matching_records(dict(completed={'$ne':None}))
1042             except Exception:
1043                 reply = error.wrap_exception()
1044         else:
1045             for msg_id in msg_ids:
1046                 if msg_id in self.all_completed:
1047                     self.db.drop_record(msg_id)
1048                 else:
1049                     if msg_id in self.pending:
1050                         try:
1051                             raise IndexError("msg pending: %r"%msg_id)
1052                         except:
1053                             reply = error.wrap_exception()
1054                     else:
1055                         try:
1056                             raise IndexError("No such msg: %r"%msg_id)
1057                         except:
1058                             reply = error.wrap_exception()
1059                     break
1060         eids = content.get('engine_ids', [])
1061         for eid in eids:
1062             if eid not in self.engines:
1063                 try:
1064                     raise IndexError("No such engine: %i"%eid)
1065                 except:
1066                     reply = error.wrap_exception()
1067                 break
1068             msg_ids = self.completed.pop(eid)
1069             uid = self.engines[eid].queue
1070             try:
1071                 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1072             except Exception:
1073                 reply = error.wrap_exception()
1074                 break
1075
1076         self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1077
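purge_request content accepts msg_ids='all', an explicit list of completed msg_ids, and/or engine_ids whose completed history should be dropped; pending msg_ids are refused with an error reply. Client-side sketch, same send signature as above:

    session.send(query, 'purge_request', content=dict(msg_ids='all'))
    session.send(query, 'purge_request',
                 content=dict(msg_ids=[], engine_ids=[0, 1]))
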
1078     def resubmit_task(self, client_id, msg, buffers):
1079         """Resubmit a task."""
1080         raise NotImplementedError
1081
1082     def _extract_record(self, rec):
1083         """decompose a TaskRecord dict into subsection of reply for get_result"""
1084         io_dict = {}
1085         for key in 'pyin pyout pyerr stdout stderr'.split():
1086             io_dict[key] = rec[key]
1087         content = { 'result_content': rec['result_content'],
1088                     'header': rec['header'],
1089                     'result_header' : rec['result_header'],
1090                     'io' : io_dict,
1091                   }
1092         if rec['result_buffers']:
1093             buffers = map(str, rec['result_buffers'])
1094         else:
1095             buffers = []
1096
1097         return content, buffers
1098
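_extract_record keeps the zmq buffers out of the JSON-able reply content; get_results below relies on that split. Its contract, assuming a fully-initialized TaskRecord (every io key present, as init_record/empty_record guarantee):

    content, bufs = self._extract_record(rec)
    # content: {'result_content': ..., 'header': ..., 'result_header': ...,
    #           'io': {'pyin': ..., 'pyout': ..., 'pyerr': ..., 'stdout': ..., 'stderr': ...}}
    # bufs:    list of str buffers, sent alongside the reply rather than inside it
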
1099     def get_results(self, client_id, msg):
1100         """Get the result of 1 or more messages."""
1101         content = msg['content']
1102         msg_ids = sorted(set(content['msg_ids']))
1103         statusonly = content.get('status_only', False)
1104         pending = []
1105         completed = []
1106         content = dict(status='ok')
1107         content['pending'] = pending
1108         content['completed'] = completed
1109         buffers = []
1110         if not statusonly:
1111             try:
1112                 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1113                 # turn match list into dict, for faster lookup
1114                 records = {}
1115                 for rec in matches:
1116                     records[rec['msg_id']] = rec
1117             except Exception:
1118                 content = error.wrap_exception()
1119                 self.session.send(self.query, "result_reply", content=content,
1120                                                     parent=msg, ident=client_id)
1121                 return
1122         else:
1123             records = {}
1124         for msg_id in msg_ids:
1125             if msg_id in self.pending:
1126                 pending.append(msg_id)
1127             elif msg_id in self.all_completed or msg_id in records:
1128                 completed.append(msg_id)
1129                 if not statusonly:
1130                     c,bufs = self._extract_record(records[msg_id])
1131                     content[msg_id] = c
1132                     buffers.extend(bufs)
1133             else:
1134                 try:
1135                     raise KeyError('No such message: '+msg_id)
1136                 except:
1137                     content = error.wrap_exception()
1138                 break
1139         self.session.send(self.query, "result_reply", content=content,
1140                                             parent=msg, ident=client_id,
1141                                             buffers=buffers)
1142
1143 def get_history(self, client_id, msg):
1144 """Get a list of all msg_ids in our DB records"""
1145 try:
1146 msg_ids = self.db.get_history()
1147 except Exception as e:
1148 content = error.wrap_exception()
1149 else:
1150 content = dict(status='ok', history=msg_ids)
1151
1152 self.session.send(self.query, "history_reply", content=content,
1153 parent=msg, ident=client_id)
1154
1155 def db_query(self, client_id, msg):
1156 """Perform a raw query on the task record database."""
1157 content = msg['content']
1158 query = content.get('query', {})
1159 keys = content.get('keys', None)
1160 query = util.extract_dates(query)
1161 buffers = []
1162 empty = list()
1163
1164 try:
1165 records = self.db.find_records(query, keys)
1166 except Exception as e:
1167 content = error.wrap_exception()
1168 else:
1169 # extract buffers from reply content:
1170 if keys is not None:
1171 buffer_lens = [] if 'buffers' in keys else None
1172 result_buffer_lens = [] if 'result_buffers' in keys else None
1173 else:
1174 buffer_lens = []
1175 result_buffer_lens = []
1176
1177 for rec in records:
1178 # buffers may be None, so double check
1179 if buffer_lens is not None:
1180 b = rec.pop('buffers', empty) or empty
1181 buffer_lens.append(len(b))
1182 buffers.extend(b)
1183 if result_buffer_lens is not None:
1184 rb = rec.pop('result_buffers', empty) or empty
1185 result_buffer_lens.append(len(rb))
1186 buffers.extend(rb)
1187 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1188 result_buffer_lens=result_buffer_lens)
1189
1190 self.session.send(self.query, "db_reply", content=content,
1191 parent=msg, ident=client_id,
1192 buffers=buffers)
1193
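Reviewer note: on the client side, the flat buffers list that arrives with a db_reply can be re-paired with its records using the per-record buffer_lens / result_buffer_lens counts built above. A sketch, assuming a successful reply shaped like the one constructed here (the function name is illustrative, not part of the client API):

    def reattach_buffers(content, buffers):
        """re-pair a db_reply's records with their binary buffers"""
        records = content['records']
        blens = content['buffer_lens']           # None unless requested
        rlens = content['result_buffer_lens']    # None unless requested
        pos = 0
        for i, rec in enumerate(records):
            if blens is not None:
                n = blens[i]
                rec['buffers'] = buffers[pos:pos+n]
                pos += n
            if rlens is not None:
                n = rlens[i]
                rec['result_buffers'] = buffers[pos:pos+n]
                pos += n
        return records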
@@ -1,80 +1,96 b''
1 """A TaskRecord backend using mongodb"""
1 """A TaskRecord backend using mongodb"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 from datetime import datetime
9 from datetime import datetime
10
10
11 from pymongo import Connection
11 from pymongo import Connection
12 from pymongo.binary import Binary
12 from pymongo.binary import Binary
13
13
14 from IPython.utils.traitlets import Dict, List, CUnicode
14 from IPython.utils.traitlets import Dict, List, CUnicode
15
15
16 from .dictdb import BaseDB
16 from .dictdb import BaseDB
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # MongoDB class
19 # MongoDB class
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 class MongoDB(BaseDB):
22 class MongoDB(BaseDB):
23 """MongoDB TaskRecord backend."""
23 """MongoDB TaskRecord backend."""
24
24
25 connection_args = List(config=True)
25 connection_args = List(config=True) # args passed to pymongo.Connection
26 connection_kwargs = Dict(config=True)
26 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
27 database = CUnicode(config=True)
27 database = CUnicode(config=True) # name of the mongodb database
28 _table = Dict()
28 _table = Dict()
29
29
30 def __init__(self, **kwargs):
30 def __init__(self, **kwargs):
31 super(MongoDB, self).__init__(**kwargs)
31 super(MongoDB, self).__init__(**kwargs)
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
33 if not self.database:
33 if not self.database:
34 self.database = self.session
34 self.database = self.session
35 self._db = self._connection[self.database]
35 self._db = self._connection[self.database]
36 self._records = self._db['task_records']
36 self._records = self._db['task_records']
37
37
38 def _binary_buffers(self, rec):
38 def _binary_buffers(self, rec):
39 for key in ('buffers', 'result_buffers'):
39 for key in ('buffers', 'result_buffers'):
40 if key in rec:
40 if rec.get(key, None):
41 rec[key] = map(Binary, rec[key])
41 rec[key] = map(Binary, rec[key])
42 return rec
42
43
43 def add_record(self, msg_id, rec):
44 def add_record(self, msg_id, rec):
44 """Add a new Task Record, by msg_id."""
45 """Add a new Task Record, by msg_id."""
45 # print rec
46 # print rec
46 rec = _binary_buffers(rec)
47 rec = self._binary_buffers(rec)
47 obj_id = self._records.insert(rec)
48 obj_id = self._records.insert(rec)
48 self._table[msg_id] = obj_id
49 self._table[msg_id] = obj_id
49
50
50 def get_record(self, msg_id):
51 def get_record(self, msg_id):
51 """Get a specific Task Record, by msg_id."""
52 """Get a specific Task Record, by msg_id."""
52 return self._records.find_one(self._table[msg_id])
53 return self._records.find_one(self._table[msg_id])
53
54
54 def update_record(self, msg_id, rec):
55 def update_record(self, msg_id, rec):
55 """Update the data in an existing record."""
56 """Update the data in an existing record."""
56 rec = _binary_buffers(rec)
57 rec = self._binary_buffers(rec)
57 obj_id = self._table[msg_id]
58 obj_id = self._table[msg_id]
58 self._records.update({'_id':obj_id}, {'$set': rec})
59 self._records.update({'_id':obj_id}, {'$set': rec})
59
60
60 def drop_matching_records(self, check):
61 def drop_matching_records(self, check):
61 """Remove a record from the DB."""
62 """Remove a record from the DB."""
62 self._records.remove(check)
63 self._records.remove(check)
63
64
64 def drop_record(self, msg_id):
65 def drop_record(self, msg_id):
65 """Remove a record from the DB."""
66 """Remove a record from the DB."""
66 obj_id = self._table.pop(msg_id)
67 obj_id = self._table.pop(msg_id)
67 self._records.remove(obj_id)
68 self._records.remove(obj_id)
68
69
69 def find_records(self, check, id_only=False):
70 def find_records(self, check, keys=None):
70 """Find records matching a query dict."""
71 """Find records matching a query dict, optionally extracting subset of keys.
71 matches = list(self._records.find(check))
72
72 if id_only:
73 Returns list of matching records.
73 return [ rec['msg_id'] for rec in matches ]
74
74 else:
75 Parameters
75 data = {}
76 ----------
76 for rec in matches:
77
77 data[rec['msg_id']] = rec
78 check: dict
78 return data
79 mongodb-style query argument
80 keys: list of strs [optional]
81 if specified, the subset of keys to extract. msg_id will *always* be
82 included.
83 """
84 if keys and 'msg_id' not in keys:
85 keys.append('msg_id')
86 matches = list(self._records.find(check,keys))
87 for rec in matches:
88 rec.pop('_id')
89 return matches
90
91 def get_history(self):
92 """get all msg_ids, ordered by time submitted."""
93 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
94 return [ rec['msg_id'] for rec in cursor ]
79
95
80
96
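Reviewer note: a usage sketch of the updated MongoDB backend (assumes a local mongod on the default port; the database name is illustrative):

    from IPython.parallel.controller.mongodb import MongoDB

    db = MongoDB(database='ipython_tasks_demo')
    db.add_record('abc', {'msg_id': 'abc', 'completed': None, 'stdout': ''})
    # keys is optional; msg_id is always included in each match
    recs = db.find_records({'completed': None}, keys=['stdout'])
    hist = db.get_history()    # all msg_ids, ordered by time submitted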
@@ -1,284 +1,312 b''
1 """A TaskRecord backend using sqlite3"""
1 """A TaskRecord backend using sqlite3"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 import json
9 import json
10 import os
10 import os
11 import cPickle as pickle
11 import cPickle as pickle
12 from datetime import datetime
12 from datetime import datetime
13
13
14 import sqlite3
14 import sqlite3
15
15
16 from zmq.eventloop import ioloop
16 from zmq.eventloop import ioloop
17
17
18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
19 from .dictdb import BaseDB
19 from .dictdb import BaseDB
20 from IPython.parallel.util import ISO8601
20 from IPython.parallel.util import ISO8601
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # SQLite operators, adapters, and converters
23 # SQLite operators, adapters, and converters
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25
25
26 operators = {
26 operators = {
27 '$lt' : lambda a,b: "%s < ?",
27 '$lt' : "<",
28 '$gt' : ">",
28 '$gt' : ">",
29 # NULL never compares equal with ==/!=, so use IS / IS NOT for $eq/$ne
29 # NULL never compares equal with ==/!=, so use IS / IS NOT for $eq/$ne
30 '$eq' : "IS",
30 '$eq' : "IS",
31 '$ne' : "IS NOT",
31 '$ne' : "IS NOT",
32 '$lte': "<=",
32 '$lte': "<=",
33 '$gte': ">=",
33 '$gte': ">=",
34 '$in' : ('IS', ' OR '),
34 '$in' : ('IS', ' OR '),
35 '$nin': ('IS NOT', ' AND '),
35 '$nin': ('IS NOT', ' AND '),
36 # '$all': None,
36 # '$all': None,
37 # '$mod': None,
37 # '$mod': None,
38 # '$exists' : None
38 # '$exists' : None
39 }
39 }
40
40
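Reviewer note: the IS / IS NOT mapping for $eq / $ne matters because SQLite's = and != never match NULL. A quick standalone illustration against an in-memory database:

    import sqlite3
    db = sqlite3.connect(':memory:')
    db.execute('CREATE TABLE t (completed text)')
    db.executemany('INSERT INTO t VALUES (?)', [(None,), ('2011-03-01',)])
    # != can never match against a bound NULL; IS NOT treats NULL as a value:
    print db.execute('SELECT count(*) FROM t WHERE completed != ?',
                     (None,)).fetchone()       # (0,)
    print db.execute('SELECT count(*) FROM t WHERE completed IS NOT ?',
                     (None,)).fetchone()       # (1,)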
41 def _adapt_datetime(dt):
41 def _adapt_datetime(dt):
42 return dt.strftime(ISO8601)
42 return dt.strftime(ISO8601)
43
43
44 def _convert_datetime(ds):
44 def _convert_datetime(ds):
45 if ds is None:
45 if ds is None:
46 return ds
46 return ds
47 else:
47 else:
48 return datetime.strptime(ds, ISO8601)
48 return datetime.strptime(ds, ISO8601)
49
49
50 def _adapt_dict(d):
50 def _adapt_dict(d):
51 return json.dumps(d)
51 return json.dumps(d)
52
52
53 def _convert_dict(ds):
53 def _convert_dict(ds):
54 if ds is None:
54 if ds is None:
55 return ds
55 return ds
56 else:
56 else:
57 return json.loads(ds)
57 return json.loads(ds)
58
58
59 def _adapt_bufs(bufs):
59 def _adapt_bufs(bufs):
60 # this is *horrible*
60 # this is *horrible*
61 # copy buffers into single list and pickle it:
61 # copy buffers into single list and pickle it:
62 if bufs and isinstance(bufs[0], (bytes, buffer)):
62 if bufs and isinstance(bufs[0], (bytes, buffer)):
63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
64 elif bufs:
64 elif bufs:
65 return bufs
65 return bufs
66 else:
66 else:
67 return None
67 return None
68
68
69 def _convert_bufs(bs):
69 def _convert_bufs(bs):
70 if bs is None:
70 if bs is None:
71 return []
71 return []
72 else:
72 else:
73 return pickle.loads(bytes(bs))
73 return pickle.loads(bytes(bs))
74
74
75 #-----------------------------------------------------------------------------
75 #-----------------------------------------------------------------------------
76 # SQLiteDB class
76 # SQLiteDB class
77 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
78
78
79 class SQLiteDB(BaseDB):
79 class SQLiteDB(BaseDB):
80 """SQLite3 TaskRecord backend."""
80 """SQLite3 TaskRecord backend."""
81
81
82 filename = CUnicode('tasks.db', config=True)
82 filename = CUnicode('tasks.db', config=True)
83 location = CUnicode('', config=True)
83 location = CUnicode('', config=True)
84 table = CUnicode("", config=True)
84 table = CUnicode("", config=True)
85
85
86 _db = Instance('sqlite3.Connection')
86 _db = Instance('sqlite3.Connection')
87 _keys = List(['msg_id' ,
87 _keys = List(['msg_id' ,
88 'header' ,
88 'header' ,
89 'content',
89 'content',
90 'buffers',
90 'buffers',
91 'submitted',
91 'submitted',
92 'client_uuid' ,
92 'client_uuid' ,
93 'engine_uuid' ,
93 'engine_uuid' ,
94 'started',
94 'started',
95 'completed',
95 'completed',
96 'resubmitted',
96 'resubmitted',
97 'result_header' ,
97 'result_header' ,
98 'result_content' ,
98 'result_content' ,
99 'result_buffers' ,
99 'result_buffers' ,
100 'queue' ,
100 'queue' ,
101 'pyin' ,
101 'pyin' ,
102 'pyout',
102 'pyout',
103 'pyerr',
103 'pyerr',
104 'stdout',
104 'stdout',
105 'stderr',
105 'stderr',
106 ])
106 ])
107
107
108 def __init__(self, **kwargs):
108 def __init__(self, **kwargs):
109 super(SQLiteDB, self).__init__(**kwargs)
109 super(SQLiteDB, self).__init__(**kwargs)
110 if not self.table:
110 if not self.table:
111 # use session, and prefix _, since starting with # is illegal
111 # use session, and prefix _, since starting with # is illegal
112 self.table = '_'+self.session.replace('-','_')
112 self.table = '_'+self.session.replace('-','_')
113 if not self.location:
113 if not self.location:
114 if hasattr(self.config.Global, 'cluster_dir'):
114 if hasattr(self.config.Global, 'cluster_dir'):
115 self.location = self.config.Global.cluster_dir
115 self.location = self.config.Global.cluster_dir
116 else:
116 else:
117 self.location = '.'
117 self.location = '.'
118 self._init_db()
118 self._init_db()
119
119
120 # register db commit as 2s periodic callback
120 # register db commit as 2s periodic callback
121 # to prevent clogging pipes
121 # to prevent clogging pipes
122 # assumes we are being run in a zmq ioloop app
122 # assumes we are being run in a zmq ioloop app
123 loop = ioloop.IOLoop.instance()
123 loop = ioloop.IOLoop.instance()
124 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
124 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
125 pc.start()
125 pc.start()
126
126
127 def _defaults(self):
127 def _defaults(self, keys=None):
128 """create an empty record"""
128 """create an empty record"""
129 d = {}
129 d = {}
130 for key in self._keys:
130 keys = self._keys if keys is None else keys
131 for key in keys:
131 d[key] = None
132 d[key] = None
132 return d
133 return d
133
134
134 def _init_db(self):
135 def _init_db(self):
135 """Connect to the database and get new session number."""
136 """Connect to the database and get new session number."""
136 # register adapters
137 # register adapters
137 sqlite3.register_adapter(datetime, _adapt_datetime)
138 sqlite3.register_adapter(datetime, _adapt_datetime)
138 sqlite3.register_converter('datetime', _convert_datetime)
139 sqlite3.register_converter('datetime', _convert_datetime)
139 sqlite3.register_adapter(dict, _adapt_dict)
140 sqlite3.register_adapter(dict, _adapt_dict)
140 sqlite3.register_converter('dict', _convert_dict)
141 sqlite3.register_converter('dict', _convert_dict)
141 sqlite3.register_adapter(list, _adapt_bufs)
142 sqlite3.register_adapter(list, _adapt_bufs)
142 sqlite3.register_converter('bufs', _convert_bufs)
143 sqlite3.register_converter('bufs', _convert_bufs)
143 # connect to the db
144 # connect to the db
144 dbfile = os.path.join(self.location, self.filename)
145 dbfile = os.path.join(self.location, self.filename)
145 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
146 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
146 # isolation_level = None)#,
147 # isolation_level = None)#,
147 cached_statements=64)
148 cached_statements=64)
148 # print dir(self._db)
149 # print dir(self._db)
149
150
150 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
151 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
151 (msg_id text PRIMARY KEY,
152 (msg_id text PRIMARY KEY,
152 header dict text,
153 header dict text,
153 content dict text,
154 content dict text,
154 buffers bufs blob,
155 buffers bufs blob,
155 submitted datetime text,
156 submitted datetime text,
156 client_uuid text,
157 client_uuid text,
157 engine_uuid text,
158 engine_uuid text,
158 started datetime text,
159 started datetime text,
159 completed datetime text,
160 completed datetime text,
160 resubmitted datetime text,
161 resubmitted datetime text,
161 result_header dict text,
162 result_header dict text,
162 result_content dict text,
163 result_content dict text,
163 result_buffers bufs blob,
164 result_buffers bufs blob,
164 queue text,
165 queue text,
165 pyin text,
166 pyin text,
166 pyout text,
167 pyout text,
167 pyerr text,
168 pyerr text,
168 stdout text,
169 stdout text,
169 stderr text)
170 stderr text)
170 """%self.table)
171 """%self.table)
171 # self._db.execute("""CREATE TABLE IF NOT EXISTS %s_buffers
172 # (msg_id text, result integer, buffer blob)
173 # """%self.table)
174 self._db.commit()
172 self._db.commit()
175
173
176 def _dict_to_list(self, d):
174 def _dict_to_list(self, d):
177 """turn a mongodb-style record dict into a list."""
175 """turn a mongodb-style record dict into a list."""
178
176
179 return [ d[key] for key in self._keys ]
177 return [ d[key] for key in self._keys ]
180
178
181 def _list_to_dict(self, line):
179 def _list_to_dict(self, line, keys=None):
182 """Inverse of dict_to_list"""
180 """Inverse of dict_to_list"""
183 d = self._defaults()
181 keys = self._keys if keys is None else keys
184 for key,value in zip(self._keys, line):
182 d = self._defaults(keys)
183 for key,value in zip(keys, line):
185 d[key] = value
184 d[key] = value
186
185
187 return d
186 return d
188
187
189 def _render_expression(self, check):
188 def _render_expression(self, check):
190 """Turn a mongodb-style search dict into an SQL query."""
189 """Turn a mongodb-style search dict into an SQL query."""
191 expressions = []
190 expressions = []
192 args = []
191 args = []
193
192
194 skeys = set(check.keys())
193 skeys = set(check.keys())
195 skeys.difference_update(set(self._keys))
194 skeys.difference_update(set(self._keys))
196 skeys.difference_update(set(['buffers', 'result_buffers']))
195 skeys.difference_update(set(['buffers', 'result_buffers']))
197 if skeys:
196 if skeys:
198 raise KeyError("Illegal testing key(s): %s"%skeys)
197 raise KeyError("Illegal testing key(s): %s"%skeys)
199
198
200 for name,sub_check in check.iteritems():
199 for name,sub_check in check.iteritems():
201 if isinstance(sub_check, dict):
200 if isinstance(sub_check, dict):
202 for test,value in sub_check.iteritems():
201 for test,value in sub_check.iteritems():
203 try:
202 try:
204 op = operators[test]
203 op = operators[test]
205 except KeyError:
204 except KeyError:
206 raise KeyError("Unsupported operator: %r"%test)
205 raise KeyError("Unsupported operator: %r"%test)
207 if isinstance(op, tuple):
206 if isinstance(op, tuple):
208 op, join = op
207 op, join = op
209 expr = "%s %s ?"%(name, op)
208 expr = "%s %s ?"%(name, op)
210 if isinstance(value, (tuple,list)):
209 if isinstance(value, (tuple,list)):
211 expr = '( %s )'%( join.join([expr]*len(value)) )
210 expr = '( %s )'%( join.join([expr]*len(value)) )
212 args.extend(value)
211 args.extend(value)
213 else:
212 else:
214 args.append(value)
213 args.append(value)
215 expressions.append(expr)
214 expressions.append(expr)
216 else:
215 else:
217 # it's an equality check
216 # it's an equality check
218 expressions.append("%s IS ?"%name)
217 expressions.append("%s IS ?"%name)
219 args.append(sub_check)
218 args.append(sub_check)
220
219
221 expr = " AND ".join(expressions)
220 expr = " AND ".join(expressions)
222 return expr, args
221 return expr, args
223
222
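Reviewer note: a sketch of what _render_expression produces for a compound check. Constructing the backend touches the filesystem, so the path, filename, and table name here are illustrative; clause order follows dict iteration order:

    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location='/tmp', filename='render_demo.db', table='_demo')
    expr, args = db._render_expression(
        {'completed': {'$gte': '2011-03-01T00:00:00.000000'},
         'msg_id': {'$in': ['a', 'b']}})
    # expr: "completed >= ? AND ( msg_id IS ? OR msg_id IS ? )"
    # args: ['2011-03-01T00:00:00.000000', 'a', 'b']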
224 def add_record(self, msg_id, rec):
223 def add_record(self, msg_id, rec):
225 """Add a new Task Record, by msg_id."""
224 """Add a new Task Record, by msg_id."""
226 d = self._defaults()
225 d = self._defaults()
227 d.update(rec)
226 d.update(rec)
228 d['msg_id'] = msg_id
227 d['msg_id'] = msg_id
229 line = self._dict_to_list(d)
228 line = self._dict_to_list(d)
230 tups = '(%s)'%(','.join(['?']*len(line)))
229 tups = '(%s)'%(','.join(['?']*len(line)))
231 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
230 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
232 # self._db.commit()
231 # self._db.commit()
233
232
234 def get_record(self, msg_id):
233 def get_record(self, msg_id):
235 """Get a specific Task Record, by msg_id."""
234 """Get a specific Task Record, by msg_id."""
236 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
235 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
237 line = cursor.fetchone()
236 line = cursor.fetchone()
238 if line is None:
237 if line is None:
239 raise KeyError("No such msg: %r"%msg_id)
238 raise KeyError("No such msg: %r"%msg_id)
240 return self._list_to_dict(line)
239 return self._list_to_dict(line)
241
240
242 def update_record(self, msg_id, rec):
241 def update_record(self, msg_id, rec):
243 """Update the data in an existing record."""
242 """Update the data in an existing record."""
244 query = "UPDATE %s SET "%self.table
243 query = "UPDATE %s SET "%self.table
245 sets = []
244 sets = []
246 keys = sorted(rec.keys())
245 keys = sorted(rec.keys())
247 values = []
246 values = []
248 for key in keys:
247 for key in keys:
249 sets.append('%s = ?'%key)
248 sets.append('%s = ?'%key)
250 values.append(rec[key])
249 values.append(rec[key])
251 query += ', '.join(sets)
250 query += ', '.join(sets)
252 query += ' WHERE msg_id == %r'%msg_id
251 query += ' WHERE msg_id == ?'
252 values.append(msg_id)
253 self._db.execute(query, values)
253 self._db.execute(query, values)
254 # self._db.commit()
254 # self._db.commit()
255
255
256 def drop_record(self, msg_id):
256 def drop_record(self, msg_id):
257 """Remove a record from the DB."""
257 """Remove a record from the DB."""
258 self._db.execute("""DELETE FROM %s WHERE mgs_id==?"""%self.table, (msg_id,))
258 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
259 # self._db.commit()
259 # self._db.commit()
260
260
261 def drop_matching_records(self, check):
261 def drop_matching_records(self, check):
262 """Remove a record from the DB."""
262 """Remove a record from the DB."""
263 expr,args = self._render_expression(check)
263 expr,args = self._render_expression(check)
264 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
264 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
265 self._db.execute(query,args)
265 self._db.execute(query,args)
266 # self._db.commit()
266 # self._db.commit()
267
267
268 def find_records(self, check, id_only=False):
268 def find_records(self, check, keys=None):
269 """Find records matching a query dict."""
269 """Find records matching a query dict, optionally extracting subset of keys.
270 req = 'msg_id' if id_only else '*'
270
271 Returns list of matching records.
272
273 Parameters
274 ----------
275
276 check: dict
277 mongodb-style query argument
278 keys: list of strs [optional]
279 if specified, the subset of keys to extract. msg_id will *always* be
280 included.
281 """
282 if keys:
283 bad_keys = [ key for key in keys if key not in self._keys ]
284 if bad_keys:
285 raise KeyError("Bad record key(s): %s"%bad_keys)
286
287 if keys:
288 # ensure msg_id is present and first:
289 if 'msg_id' in keys:
290 keys.remove('msg_id')
291 keys.insert(0, 'msg_id')
292 req = ', '.join(keys)
293 else:
294 req = '*'
271 expr,args = self._render_expression(check)
295 expr,args = self._render_expression(check)
272 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
296 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
273 cursor = self._db.execute(query, args)
297 cursor = self._db.execute(query, args)
274 matches = cursor.fetchall()
298 matches = cursor.fetchall()
275 if id_only:
299 records = []
276 return [ m[0] for m in matches ]
300 for line in matches:
277 else:
301 rec = self._list_to_dict(line, keys)
278 records = {}
302 records.append(rec)
279 for line in matches:
303 return records
280 rec = self._list_to_dict(line)
304
281 records[rec['msg_id']] = rec
305 def get_history(self):
282 return records
306 """get all msg_ids, ordered by time submitted."""
307 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
308 cursor = self._db.execute(query)
309 # will be a list of length 1 tuples
310 return [ tup[0] for tup in cursor.fetchall()]
283
311
284 __all__ = ['SQLiteDB']
\ No newline at end of file
312 __all__ = ['SQLiteDB']
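Reviewer note: end-to-end usage of the SQLite backend's new keys support (path and names illustrative):

    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location='/tmp', filename='tasks_demo.db', table='_demo2')
    db.add_record('a', {'stdout': 'hi', 'completed': None})
    # msg_id is always included in the result, and always comes first:
    recs = db.find_records({'completed': {'$eq': None}}, keys=['stdout'])
    # -> [{'msg_id': 'a', 'stdout': 'hi'}]
    hist = db.get_history()    # all msg_ids, ordered by submitted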
@@ -1,402 +1,410 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
5 # Copyright (C) 2010-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11
11
12 import os
12 import os
13 import pprint
13 import pprint
14 import uuid
14 import uuid
15 from datetime import datetime
15 from datetime import datetime
16
16
17 try:
17 try:
18 import cPickle
18 import cPickle
19 pickle = cPickle
19 pickle = cPickle
20 except:
20 except:
21 cPickle = None
21 cPickle = None
22 import pickle
22 import pickle
23
23
24 import zmq
24 import zmq
25 from zmq.utils import jsonapi
25 from zmq.utils import jsonapi
26 from zmq.eventloop.zmqstream import ZMQStream
26 from zmq.eventloop.zmqstream import ZMQStream
27
27
28 from .util import ISO8601
28 from .util import ISO8601
29
29
30 def squash_unicode(obj):
30 def squash_unicode(obj):
31 """coerce unicode back to bytestrings."""
31 if isinstance(obj,dict):
32 if isinstance(obj,dict):
32 for key in obj.keys():
33 for key in obj.keys():
33 obj[key] = squash_unicode(obj[key])
34 obj[key] = squash_unicode(obj[key])
34 if isinstance(key, unicode):
35 if isinstance(key, unicode):
35 obj[squash_unicode(key)] = obj.pop(key)
36 obj[squash_unicode(key)] = obj.pop(key)
36 elif isinstance(obj, list):
37 elif isinstance(obj, list):
37 for i,v in enumerate(obj):
38 for i,v in enumerate(obj):
38 obj[i] = squash_unicode(v)
39 obj[i] = squash_unicode(v)
39 elif isinstance(obj, unicode):
40 elif isinstance(obj, unicode):
40 obj = obj.encode('utf8')
41 obj = obj.encode('utf8')
41 return obj
42 return obj
42
43
43 json_packer = jsonapi.dumps
44 def _date_default(obj):
45 if isinstance(obj, datetime):
46 return obj.strftime(ISO8601)
47 else:
48 raise TypeError("%r is not JSON serializable"%obj)
49
50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
44 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
45
53
46 pickle_packer = lambda o: pickle.dumps(o,-1)
54 pickle_packer = lambda o: pickle.dumps(o,-1)
47 pickle_unpacker = pickle.loads
55 pickle_unpacker = pickle.loads
48
56
49 default_packer = json_packer
57 default_packer = json_packer
50 default_unpacker = json_unpacker
58 default_unpacker = json_unpacker
51
59
52
60
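Reviewer note: with the new _date_default hook, the JSON packer serializes datetime objects as ISO8601 strings, and they come back off the wire as strings until util.extract_dates (added in this changeset) re-parses them. A round-trip sketch:

    from datetime import datetime
    from IPython.parallel.streamsession import json_packer, json_unpacker
    from IPython.parallel.util import extract_dates

    wire = json_packer({'submitted': datetime.now()})  # ISO8601 text on the wire
    restored = extract_dates(json_unpacker(wire))      # back to a datetime
    assert isinstance(restored['submitted'], datetime)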
53 DELIM="<IDS|MSG>"
61 DELIM="<IDS|MSG>"
54
62
55 class Message(object):
63 class Message(object):
56 """A simple message object that maps dict keys to attributes.
64 """A simple message object that maps dict keys to attributes.
57
65
58 A Message can be created from a dict and a dict from a Message instance
66 A Message can be created from a dict and a dict from a Message instance
59 simply by calling dict(msg_obj)."""
67 simply by calling dict(msg_obj)."""
60
68
61 def __init__(self, msg_dict):
69 def __init__(self, msg_dict):
62 dct = self.__dict__
70 dct = self.__dict__
63 for k, v in dict(msg_dict).iteritems():
71 for k, v in dict(msg_dict).iteritems():
64 if isinstance(v, dict):
72 if isinstance(v, dict):
65 v = Message(v)
73 v = Message(v)
66 dct[k] = v
74 dct[k] = v
67
75
68 # Having this iterator lets dict(msg_obj) work out of the box.
76 # Having this iterator lets dict(msg_obj) work out of the box.
69 def __iter__(self):
77 def __iter__(self):
70 return iter(self.__dict__.iteritems())
78 return iter(self.__dict__.iteritems())
71
79
72 def __repr__(self):
80 def __repr__(self):
73 return repr(self.__dict__)
81 return repr(self.__dict__)
74
82
75 def __str__(self):
83 def __str__(self):
76 return pprint.pformat(self.__dict__)
84 return pprint.pformat(self.__dict__)
77
85
78 def __contains__(self, k):
86 def __contains__(self, k):
79 return k in self.__dict__
87 return k in self.__dict__
80
88
81 def __getitem__(self, k):
89 def __getitem__(self, k):
82 return self.__dict__[k]
90 return self.__dict__[k]
83
91
84
92
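Reviewer note: a quick illustration of Message's attribute access (values invented):

    from IPython.parallel.streamsession import Message

    m = Message({'header': {'msg_id': 'abc'}, 'content': {}})
    print m.header.msg_id   # 'abc' -- nested dicts become Messages too
    print dict(m)           # dict(msg_obj) works out of the box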
85 def msg_header(msg_id, msg_type, username, session):
93 def msg_header(msg_id, msg_type, username, session):
86 date=datetime.now().strftime(ISO8601)
94 date=datetime.now().strftime(ISO8601)
87 return locals()
95 return locals()
88
96
89 def extract_header(msg_or_header):
97 def extract_header(msg_or_header):
90 """Given a message or header, return the header."""
98 """Given a message or header, return the header."""
91 if not msg_or_header:
99 if not msg_or_header:
92 return {}
100 return {}
93 try:
101 try:
94 # See if msg_or_header is the entire message.
102 # See if msg_or_header is the entire message.
95 h = msg_or_header['header']
103 h = msg_or_header['header']
96 except KeyError:
104 except KeyError:
97 try:
105 try:
98 # See if msg_or_header is just the header
106 # See if msg_or_header is just the header
99 h = msg_or_header['msg_id']
107 h = msg_or_header['msg_id']
100 except KeyError:
108 except KeyError:
101 raise
109 raise
102 else:
110 else:
103 h = msg_or_header
111 h = msg_or_header
104 if not isinstance(h, dict):
112 if not isinstance(h, dict):
105 h = dict(h)
113 h = dict(h)
106 return h
114 return h
107
115
108 class StreamSession(object):
116 class StreamSession(object):
109 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
117 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
110 debug=False
118 debug=False
111 key=None
119 key=None
112
120
113 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
114 if username is None:
122 if username is None:
115 username = os.environ.get('USER','username')
123 username = os.environ.get('USER','username')
116 self.username = username
124 self.username = username
117 if session is None:
125 if session is None:
118 self.session = str(uuid.uuid4())
126 self.session = str(uuid.uuid4())
119 else:
127 else:
120 self.session = session
128 self.session = session
121 self.msg_id = str(uuid.uuid4())
129 self.msg_id = str(uuid.uuid4())
122 if packer is None:
130 if packer is None:
123 self.pack = default_packer
131 self.pack = default_packer
124 else:
132 else:
125 if not callable(packer):
133 if not callable(packer):
126 raise TypeError("packer must be callable, not %s"%type(packer))
134 raise TypeError("packer must be callable, not %s"%type(packer))
127 self.pack = packer
135 self.pack = packer
128
136
129 if unpacker is None:
137 if unpacker is None:
130 self.unpack = default_unpacker
138 self.unpack = default_unpacker
131 else:
139 else:
132 if not callable(unpacker):
140 if not callable(unpacker):
133 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
134 self.unpack = unpacker
142 self.unpack = unpacker
135
143
136 if key is not None and keyfile is not None:
144 if key is not None and keyfile is not None:
137 raise TypeError("Must specify key OR keyfile, not both")
145 raise TypeError("Must specify key OR keyfile, not both")
138 if keyfile is not None:
146 if keyfile is not None:
139 with open(keyfile) as f:
147 with open(keyfile) as f:
140 self.key = f.read().strip()
148 self.key = f.read().strip()
141 else:
149 else:
142 self.key = key
150 self.key = key
143 if isinstance(self.key, unicode):
151 if isinstance(self.key, unicode):
144 self.key = self.key.encode('utf8')
152 self.key = self.key.encode('utf8')
145 # print key, keyfile, self.key
153 # print key, keyfile, self.key
146 self.none = self.pack({})
154 self.none = self.pack({})
147
155
148 def msg_header(self, msg_type):
156 def msg_header(self, msg_type):
149 h = msg_header(self.msg_id, msg_type, self.username, self.session)
157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
150 self.msg_id = str(uuid.uuid4())
158 self.msg_id = str(uuid.uuid4())
151 return h
159 return h
152
160
153 def msg(self, msg_type, content=None, parent=None, subheader=None):
161 def msg(self, msg_type, content=None, parent=None, subheader=None):
154 msg = {}
162 msg = {}
155 msg['header'] = self.msg_header(msg_type)
163 msg['header'] = self.msg_header(msg_type)
156 msg['msg_id'] = msg['header']['msg_id']
164 msg['msg_id'] = msg['header']['msg_id']
157 msg['parent_header'] = {} if parent is None else extract_header(parent)
165 msg['parent_header'] = {} if parent is None else extract_header(parent)
158 msg['msg_type'] = msg_type
166 msg['msg_type'] = msg_type
159 msg['content'] = {} if content is None else content
167 msg['content'] = {} if content is None else content
160 sub = {} if subheader is None else subheader
168 sub = {} if subheader is None else subheader
161 msg['header'].update(sub)
169 msg['header'].update(sub)
162 return msg
170 return msg
163
171
164 def check_key(self, msg_or_header):
172 def check_key(self, msg_or_header):
165 """Check that a message's header has the right key"""
173 """Check that a message's header has the right key"""
166 if self.key is None:
174 if self.key is None:
167 return True
175 return True
168 header = extract_header(msg_or_header)
176 header = extract_header(msg_or_header)
169 return header.get('key', None) == self.key
177 return header.get('key', None) == self.key
170
178
171
179
172 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
180 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
173 """Build and send a message via stream or socket.
181 """Build and send a message via stream or socket.
174
182
175 Parameters
183 Parameters
176 ----------
184 ----------
177
185
178 stream : zmq.Socket or ZMQStream
186 stream : zmq.Socket or ZMQStream
179 the socket-like object used to send the data
187 the socket-like object used to send the data
180 msg_or_type : str or Message/dict
188 msg_or_type : str or Message/dict
181 Normally, msg_or_type will be a msg_type unless a message is being sent more
189 Normally, msg_or_type will be a msg_type unless a message is being sent more
182 than once.
190 than once.
183
191
184 content : dict or None
192 content : dict or None
185 the content of the message (ignored if msg_or_type is a message)
193 the content of the message (ignored if msg_or_type is a message)
186 buffers : list or None
194 buffers : list or None
187 the already-serialized buffers to be appended to the message
195 the already-serialized buffers to be appended to the message
188 parent : Message or dict or None
196 parent : Message or dict or None
189 the parent or parent header describing the parent of this message
197 the parent or parent header describing the parent of this message
190 subheader : dict or None
198 subheader : dict or None
191 extra header keys for this message's header
199 extra header keys for this message's header
192 ident : bytes or list of bytes
200 ident : bytes or list of bytes
193 the zmq.IDENTITY routing path
201 the zmq.IDENTITY routing path
194 track : bool
202 track : bool
195 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
203 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
196
204
197 Returns
205 Returns
198 -------
206 -------
199 msg : message dict
207 msg : message dict
200 the constructed message
208 the constructed message
201 (msg,tracker) : (message dict, MessageTracker)
209 (msg,tracker) : (message dict, MessageTracker)
202 if track=True, then a 2-tuple will be returned, the first element being the constructed
210 if track=True, then a 2-tuple will be returned, the first element being the constructed
203 message, and the second being the MessageTracker
211 message, and the second being the MessageTracker
204
212
205 """
213 """
206
214
207 if not isinstance(stream, (zmq.Socket, ZMQStream)):
215 if not isinstance(stream, (zmq.Socket, ZMQStream)):
208 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
216 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
209 elif track and isinstance(stream, ZMQStream):
217 elif track and isinstance(stream, ZMQStream):
210 raise TypeError("ZMQStream cannot track messages")
218 raise TypeError("ZMQStream cannot track messages")
211
219
212 if isinstance(msg_or_type, (Message, dict)):
220 if isinstance(msg_or_type, (Message, dict)):
213 # we got a Message, not a msg_type
221 # we got a Message, not a msg_type
214 # don't build a new Message
222 # don't build a new Message
215 msg = msg_or_type
223 msg = msg_or_type
216 content = msg['content']
224 content = msg['content']
217 else:
225 else:
218 msg = self.msg(msg_or_type, content, parent, subheader)
226 msg = self.msg(msg_or_type, content, parent, subheader)
219
227
220 buffers = [] if buffers is None else buffers
228 buffers = [] if buffers is None else buffers
221 to_send = []
229 to_send = []
222 if isinstance(ident, list):
230 if isinstance(ident, list):
223 # accept list of idents
231 # accept list of idents
224 to_send.extend(ident)
232 to_send.extend(ident)
225 elif ident is not None:
233 elif ident is not None:
226 to_send.append(ident)
234 to_send.append(ident)
227 to_send.append(DELIM)
235 to_send.append(DELIM)
228 if self.key is not None:
236 if self.key is not None:
229 to_send.append(self.key)
237 to_send.append(self.key)
230 to_send.append(self.pack(msg['header']))
238 to_send.append(self.pack(msg['header']))
231 to_send.append(self.pack(msg['parent_header']))
239 to_send.append(self.pack(msg['parent_header']))
232
240
233 if content is None:
241 if content is None:
234 content = self.none
242 content = self.none
235 elif isinstance(content, dict):
243 elif isinstance(content, dict):
236 content = self.pack(content)
244 content = self.pack(content)
237 elif isinstance(content, bytes):
245 elif isinstance(content, bytes):
238 # content is already packed, as in a relayed message
246 # content is already packed, as in a relayed message
239 pass
247 pass
240 else:
248 else:
241 raise TypeError("Content incorrect type: %s"%type(content))
249 raise TypeError("Content incorrect type: %s"%type(content))
242 to_send.append(content)
250 to_send.append(content)
243 flag = 0
251 flag = 0
244 if buffers:
252 if buffers:
245 flag = zmq.SNDMORE
253 flag = zmq.SNDMORE
246 _track = False
254 _track = False
247 else:
255 else:
248 _track=track
256 _track=track
249 if track:
257 if track:
250 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
258 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
251 else:
259 else:
252 tracker = stream.send_multipart(to_send, flag, copy=False)
260 tracker = stream.send_multipart(to_send, flag, copy=False)
253 for b in buffers[:-1]:
261 for b in buffers[:-1]:
254 stream.send(b, flag, copy=False)
262 stream.send(b, flag, copy=False)
255 if buffers:
263 if buffers:
256 if track:
264 if track:
257 tracker = stream.send(buffers[-1], copy=False, track=track)
265 tracker = stream.send(buffers[-1], copy=False, track=track)
258 else:
266 else:
259 tracker = stream.send(buffers[-1], copy=False)
267 tracker = stream.send(buffers[-1], copy=False)
260
268
261 # omsg = Message(msg)
269 # omsg = Message(msg)
262 if self.debug:
270 if self.debug:
263 pprint.pprint(msg)
271 pprint.pprint(msg)
264 pprint.pprint(to_send)
272 pprint.pprint(to_send)
265 pprint.pprint(buffers)
273 pprint.pprint(buffers)
266
274
267 msg['tracker'] = tracker
275 msg['tracker'] = tracker
268
276
269 return msg
277 return msg
270
278
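Reviewer note: a minimal send/recv round trip exercising the method above; the inproc PAIR socket setup is illustrative, not how the hub wires things:

    import zmq
    from IPython.parallel.streamsession import StreamSession

    ctx = zmq.Context.instance()
    a = ctx.socket(zmq.PAIR); a.bind('inproc://demo')
    b = ctx.socket(zmq.PAIR); b.connect('inproc://demo')

    session = StreamSession(username='demo')
    session.send(a, 'apply_request', content=dict(a=5))
    idents, msg = session.recv(b, mode=0)     # mode=0: block until it arrives
    print msg['msg_type'], msg['content']     # apply_request {'a': 5}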
271 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
279 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
272 """Send a raw message via ident path.
280 """Send a raw message via ident path.
273
281
274 Parameters
282 Parameters
275 ----------
283 ----------
276 msg : list of sendable buffers"""
284 msg : list of sendable buffers"""
277 to_send = []
285 to_send = []
278 if isinstance(ident, bytes):
286 if isinstance(ident, bytes):
279 ident = [ident]
287 ident = [ident]
280 if ident is not None:
288 if ident is not None:
281 to_send.extend(ident)
289 to_send.extend(ident)
282 to_send.append(DELIM)
290 to_send.append(DELIM)
283 if self.key is not None:
291 if self.key is not None:
284 to_send.append(self.key)
292 to_send.append(self.key)
285 to_send.extend(msg)
293 to_send.extend(msg)
286 stream.send_multipart(to_send, flags, copy=copy)
294 stream.send_multipart(to_send, flags, copy=copy)
287
295
288 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
296 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
289 """receives and unpacks a message
297 """receives and unpacks a message
290 returns [idents], msg"""
298 returns [idents], msg"""
291 if isinstance(socket, ZMQStream):
299 if isinstance(socket, ZMQStream):
292 socket = socket.socket
300 socket = socket.socket
293 try:
301 try:
294 msg = socket.recv_multipart(mode)
302 msg = socket.recv_multipart(mode)
295 except zmq.ZMQError as e:
303 except zmq.ZMQError as e:
296 if e.errno == zmq.EAGAIN:
304 if e.errno == zmq.EAGAIN:
297 # We can convert EAGAIN to None as we know in this case
305 # We can convert EAGAIN to None as we know in this case
298 # recv_multipart won't return None.
306 # recv_multipart won't return None.
299 return None
307 return None
300 else:
308 else:
301 raise
309 raise
302 # return an actual Message object
310 # return an actual Message object
303 # determine the number of idents by trying to unpack them.
311 # determine the number of idents by trying to unpack them.
304 # this is terrible:
312 # this is terrible:
305 idents, msg = self.feed_identities(msg, copy)
313 idents, msg = self.feed_identities(msg, copy)
306 try:
314 try:
307 return idents, self.unpack_message(msg, content=content, copy=copy)
315 return idents, self.unpack_message(msg, content=content, copy=copy)
308 except Exception as e:
316 except Exception as e:
309 print (idents, msg)
317 print (idents, msg)
310 # TODO: handle it
318 # TODO: handle it
311 raise e
319 raise e
312
320
313 def feed_identities(self, msg, copy=True):
321 def feed_identities(self, msg, copy=True):
314 """feed until DELIM is reached, then return the prefix as idents and remainder as
322 """feed until DELIM is reached, then return the prefix as idents and remainder as
315 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
323 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
316
324
317 Parameters
325 Parameters
318 ----------
326 ----------
319 msg : a list of Message or bytes objects
327 msg : a list of Message or bytes objects
320 the message to be split
328 the message to be split
321 copy : bool
329 copy : bool
322 flag determining whether the arguments are bytes or Messages
330 flag determining whether the arguments are bytes or Messages
323
331
324 Returns
332 Returns
325 -------
333 -------
326 (idents,msg) : two lists
334 (idents,msg) : two lists
327 idents will always be a list of bytes - the identity prefix
335 idents will always be a list of bytes - the identity prefix
328 msg will be a list of bytes or Messages, unchanged from input
336 msg will be a list of bytes or Messages, unchanged from input
329 msg should be unpackable via self.unpack_message at this point.
337 msg should be unpackable via self.unpack_message at this point.
330 """
338 """
331 ikey = int(self.key is not None)
339 ikey = int(self.key is not None)
332 minlen = 3 + ikey
340 minlen = 3 + ikey
333 msg = list(msg)
341 msg = list(msg)
334 idents = []
342 idents = []
335 while len(msg) > minlen:
343 while len(msg) > minlen:
336 if copy:
344 if copy:
337 s = msg[0]
345 s = msg[0]
338 else:
346 else:
339 s = msg[0].bytes
347 s = msg[0].bytes
340 if s == DELIM:
348 if s == DELIM:
341 msg.pop(0)
349 msg.pop(0)
342 break
350 break
343 else:
351 else:
344 idents.append(s)
352 idents.append(s)
345 msg.pop(0)
353 msg.pop(0)
346
354
347 return idents, msg
355 return idents, msg
348
356
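Reviewer note: the wire layout feed_identities consumes is [idents..., DELIM, (key,) header, parent_header, content, buffers...]. A tiny check, with stand-in bytestrings for the frames:

    from IPython.parallel.streamsession import StreamSession, DELIM

    s = StreamSession()
    wire = ['routing-id', DELIM, 'header', 'parent', 'content']
    idents, rest = s.feed_identities(wire)
    assert idents == ['routing-id']
    assert rest == ['header', 'parent', 'content']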
349 def unpack_message(self, msg, content=True, copy=True):
357 def unpack_message(self, msg, content=True, copy=True):
350 """Return a message object from the format
358 """Return a message object from the format
351 sent by self.send.
359 sent by self.send.
352
360
353 Parameters:
361 Parameters:
354 -----------
362 -----------
355
363
356 content : bool (True)
364 content : bool (True)
357 whether to unpack the content dict (True),
365 whether to unpack the content dict (True),
358 or leave it serialized (False)
366 or leave it serialized (False)
359
367
360 copy : bool (True)
368 copy : bool (True)
361 whether to return the bytes (True),
369 whether to return the bytes (True),
362 or the non-copying Message object in each place (False)
370 or the non-copying Message object in each place (False)
363
371
364 """
372 """
365 ikey = int(self.key is not None)
373 ikey = int(self.key is not None)
366 minlen = 3 + ikey
374 minlen = 3 + ikey
367 message = {}
375 message = {}
368 if not copy:
376 if not copy:
369 for i in range(minlen):
377 for i in range(minlen):
370 msg[i] = msg[i].bytes
378 msg[i] = msg[i].bytes
371 if ikey:
379 if ikey:
372 if not self.key == msg[0]:
380 if not self.key == msg[0]:
373 raise KeyError("Invalid Session Key: %s"%msg[0])
381 raise KeyError("Invalid Session Key: %s"%msg[0])
374 if not len(msg) >= minlen:
382 if not len(msg) >= minlen:
375 raise TypeError("malformed message, must have at least %i elements"%minlen)
383 raise TypeError("malformed message, must have at least %i elements"%minlen)
376 message['header'] = self.unpack(msg[ikey+0])
384 message['header'] = self.unpack(msg[ikey+0])
377 message['msg_type'] = message['header']['msg_type']
385 message['msg_type'] = message['header']['msg_type']
378 message['parent_header'] = self.unpack(msg[ikey+1])
386 message['parent_header'] = self.unpack(msg[ikey+1])
379 if content:
387 if content:
380 message['content'] = self.unpack(msg[ikey+2])
388 message['content'] = self.unpack(msg[ikey+2])
381 else:
389 else:
382 message['content'] = msg[ikey+2]
390 message['content'] = msg[ikey+2]
383
391
384 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
392 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
385 return message
393 return message
386
394
387
395
388 def test_msg2obj():
396 def test_msg2obj():
389 am = dict(x=1)
397 am = dict(x=1)
390 ao = Message(am)
398 ao = Message(am)
391 assert ao.x == am['x']
399 assert ao.x == am['x']
392
400
393 am['y'] = dict(z=1)
401 am['y'] = dict(z=1)
394 ao = Message(am)
402 ao = Message(am)
395 assert ao.y.z == am['y']['z']
403 assert ao.y.z == am['y']['z']
396
404
397 k1, k2 = 'y', 'z'
405 k1, k2 = 'y', 'z'
398 assert ao[k1][k2] == am[k1][k2]
406 assert ao[k1][k2] == am[k1][k2]
399
407
400 am2 = dict(ao)
408 am2 = dict(ao)
401 assert am['x'] == am2['x']
409 assert am['x'] == am2['x']
402 assert am['y']['z'] == am2['y']['z']
410 assert am['y']['z'] == am2['y']['z']
@@ -1,462 +1,477 b''
1 """some generic utilities for dealing with classes, urls, and serialization"""
1 """some generic utilities for dealing with classes, urls, and serialization"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
3 # Copyright (C) 2010-2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 # Standard library imports.
13 # Standard library imports.
14 import logging
14 import logging
15 import os
15 import os
16 import re
16 import re
17 import stat
17 import stat
18 import socket
18 import socket
19 import sys
19 import sys
20 from datetime import datetime
20 from signal import signal, SIGINT, SIGABRT, SIGTERM
21 from signal import signal, SIGINT, SIGABRT, SIGTERM
21 try:
22 try:
22 from signal import SIGKILL
23 from signal import SIGKILL
23 except ImportError:
24 except ImportError:
24 SIGKILL=None
25 SIGKILL=None
25
26
26 try:
27 try:
27 import cPickle
28 import cPickle
28 pickle = cPickle
29 pickle = cPickle
29 except:
30 except:
30 cPickle = None
31 cPickle = None
31 import pickle
32 import pickle
32
33
33 # System library imports
34 # System library imports
34 import zmq
35 import zmq
35 from zmq.log import handlers
36 from zmq.log import handlers
36
37
37 # IPython imports
38 # IPython imports
38 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
39 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
39 from IPython.utils.newserialized import serialize, unserialize
40 from IPython.utils.newserialized import serialize, unserialize
40 from IPython.zmq.log import EnginePUBHandler
41 from IPython.zmq.log import EnginePUBHandler
41
42
42 # globals
43 # globals
43 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
44 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
45 ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
44
46
45 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
46 # Classes
48 # Classes
47 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
48
50
49 class Namespace(dict):
51 class Namespace(dict):
50 """Subclass of dict for attribute access to keys."""
52 """Subclass of dict for attribute access to keys."""
51
53
52 def __getattr__(self, key):
54 def __getattr__(self, key):
53 """getattr aliased to getitem"""
55 """getattr aliased to getitem"""
54 if key in self.iterkeys():
56 if key in self.iterkeys():
55 return self[key]
57 return self[key]
56 else:
58 else:
57 raise NameError(key)
59 raise NameError(key)
58
60
59 def __setattr__(self, key, value):
61 def __setattr__(self, key, value):
60 """setattr aliased to setitem, with strict"""
62 """setattr aliased to setitem, with strict"""
61 if hasattr(dict, key):
63 if hasattr(dict, key):
62 raise KeyError("Cannot override dict keys %r"%key)
64 raise KeyError("Cannot override dict keys %r"%key)
63 self[key] = value
65 self[key] = value
64
66
65
67
66 class ReverseDict(dict):
68 class ReverseDict(dict):
67 """simple double-keyed subset of dict methods."""
69 """simple double-keyed subset of dict methods."""
68
70
69 def __init__(self, *args, **kwargs):
71 def __init__(self, *args, **kwargs):
70 dict.__init__(self, *args, **kwargs)
72 dict.__init__(self, *args, **kwargs)
71 self._reverse = dict()
73 self._reverse = dict()
72 for key, value in self.iteritems():
74 for key, value in self.iteritems():
73 self._reverse[value] = key
75 self._reverse[value] = key
74
76
75 def __getitem__(self, key):
77 def __getitem__(self, key):
76 try:
78 try:
77 return dict.__getitem__(self, key)
79 return dict.__getitem__(self, key)
78 except KeyError:
80 except KeyError:
79 return self._reverse[key]
81 return self._reverse[key]
80
82
81 def __setitem__(self, key, value):
83 def __setitem__(self, key, value):
82 if key in self._reverse:
84 if key in self._reverse:
83 raise KeyError("Can't have key %r on both sides!"%key)
85 raise KeyError("Can't have key %r on both sides!"%key)
84 dict.__setitem__(self, key, value)
86 dict.__setitem__(self, key, value)
85 self._reverse[value] = key
87 self._reverse[value] = key
86
88
87 def pop(self, key):
89 def pop(self, key):
88 value = dict.pop(self, key)
90 value = dict.pop(self, key)
89 self._reverse.pop(value)
91 self._reverse.pop(value)
90 return value
92 return value
91
93
92 def get(self, key, default=None):
94 def get(self, key, default=None):
93 try:
95 try:
94 return self[key]
96 return self[key]
95 except KeyError:
97 except KeyError:
96 return default
98 return default
97
99
98 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
99 # Functions
101 # Functions
100 #-----------------------------------------------------------------------------
102 #-----------------------------------------------------------------------------
101
103
104 def extract_dates(obj):
105 """extract ISO8601 dates from unpacked JSON"""
106 if isinstance(obj, dict):
107 for k,v in obj.iteritems():
108 obj[k] = extract_dates(v)
109 elif isinstance(obj, list):
110 obj = [ extract_dates(o) for o in obj ]
111 elif isinstance(obj, basestring):
112 if ISO8601_RE.match(obj):
113 obj = datetime.strptime(obj, ISO8601)
114 return obj
115
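Reviewer note: extract_dates walks nested containers and re-parses any string matching ISO8601_RE; everything else passes through untouched:

    from datetime import datetime
    from IPython.parallel.util import extract_dates

    obj = extract_dates({'submitted': '2011-03-01T12:00:00.000000',
                         'stdout': 'hi',
                         'history': ['2011-03-01T12:00:01.000000']})
    assert isinstance(obj['submitted'], datetime)
    assert isinstance(obj['history'][0], datetime)
    assert obj['stdout'] == 'hi'   # non-date strings pass through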
def validate_url(url):
    """validate a url for zeromq"""
    if not isinstance(url, basestring):
        raise TypeError("url must be a string, not %r" % type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r' % url
    proto, addr = proto_addr
    assert proto in ['tcp', 'pgm', 'epgm', 'ipc', 'inproc'], "Invalid protocol: %r" % proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r' % url
        addr, s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            raise AssertionError("Invalid port %r in url: %r" % (s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r' % url

    else:
        # only validate tcp urls currently
        pass

    return True

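# Example (illustrative sketch): tcp urls are checked host-and-port; other
# transports are currently accepted as-is.
#
#   >>> validate_url('tcp://127.0.0.1:10101')
#   True
#   >>> validate_url('tcp://localhost')      # raises AssertionError: no port
#   >>> validate_url('ipc:///tmp/sock')      # non-tcp: not inspected
#   True
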
def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    if isinstance(container, basestring):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        container = container.itervalues()

    for element in container:
        validate_url_container(element)

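# Example (illustrative sketch): validates every url in a nested structure,
# such as a dict of engine connection info.
#
#   >>> validate_url_container({'task' : 'tcp://*:5555',
#   ...                         'iopub' : ['tcp://*:5556', 'tcp://*:5557']})
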
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp', 'ip', 'port')."""
    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r' % url
    proto, addr = proto_addr
    lis = addr.split(':')
    assert len(lis) == 2, 'Invalid url: %r' % url
    addr, s_port = lis
    return proto, addr, s_port

def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable ones,
    based on the location; if location is unspecified, it is interpreted
    as localhost."""
    if ip in ('0.0.0.0', '*'):
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
        if location is None or location in external_ips:
            ip = '127.0.0.1'
        elif location:
            return location
    return ip

def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable ones,
    based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not a tcp url; could be ipc, etc.
        return url

    ip = disambiguate_ip_address(ip, location)

    return "%s://%s:%s" % (proto, ip, port)

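# Example (illustrative sketch): turn a bind url into something an engine
# can actually connect to.
#
#   >>> split_url('tcp://192.168.1.5:10101')
#   ('tcp', '192.168.1.5', '10101')
#   >>> disambiguate_url('tcp://*:10101')                 # no location: local
#   'tcp://127.0.0.1:10101'
#   >>> disambiguate_url('tcp://*:10101', '192.168.1.5')  # assuming that ip
#   'tcp://192.168.1.5:10101'                             # is not this host's
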
def rekey(dikt):
    """Rekey a dict that has been forced to use str keys where there should be
    ints by json. This belongs in the jsonutil added by fperez."""
    # iterate over a snapshot of the keys, since the loop mutates dikt
    for k in dikt.keys():
        if isinstance(k, str):
            ik = fk = None
            try:
                ik = int(k)
            except ValueError:
                try:
                    fk = float(k)
                except ValueError:
                    continue
            if ik is not None:
                nk = ik
            else:
                nk = fk
            if nk in dikt:
                raise KeyError("already have key %r" % nk)
            dikt[nk] = dikt.pop(k)
    return dikt

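# Example (illustrative sketch): json turns int keys into strings; rekey
# turns them back.
#
#   >>> rekey({'0' : 'a', '1' : 'b', 'pi' : 3.14})
#   {0: 'a', 1: 'b', 'pi': 3.14}
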
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------
    obj : object
        The object to be serialized
    threshold : float
        The size threshold above which an object's data is shipped as a
        separate buffer (zero-copy where possible) instead of being pickled
        inline with its metadata.

    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        clist = canSequence(obj)
        slist = map(serialize, clist)
        for s in slist:
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
        return pickle.dumps(slist, -1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        for k in sorted(obj.iterkeys()):
            s = serialize(can(obj[k]))
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
            sobj[k] = s
        return pickle.dumps(sobj, -1), databuffers
    else:
        s = serialize(can(obj))
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return pickle.dumps(s, -1), databuffers

def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers."""
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj)), bufs
    elif isinstance(sobj, dict):
        newobj = {}
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj, bufs
    else:
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj)), bufs

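# Round-trip sketch (illustrative; `serialize`, `can`, and friends come from
# IPython's serialization helpers imported at the top of this module):
#
#   >>> import numpy
#   >>> A = numpy.ones((1024, 1024))
#   >>> pmd, bufs = serialize_object(A)        # ndarray data sent out-of-band
#   >>> B, remainder = unserialize_object([pmd] + bufs)
#   >>> (B == A).all()
#   True
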
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have its data copied (currently only numpy arrays support zero-copy)."""
    msg = [pickle.dumps(can(f), -1)]
    databuffers = []  # for large objects
    sargs, bufs = serialize_object(args, threshold)
    msg.append(sargs)
    databuffers.extend(bufs)
    skwargs, bufs = serialize_object(kwargs, threshold)
    msg.append(skwargs)
    databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg

def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message().

    Returns: original f,args,kwargs"""
    bufs = list(bufs)  # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    for sa in sargs:
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                if copy:
                    sa.data = buffer(m)
                else:
                    sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                if copy:
                    sa.data = buffer(m)
                else:
                    sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

        kwargs[k] = uncan(unserialize(sa), g)

    return f, args, kwargs

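# Round-trip sketch (illustrative, run in a single process):
#
#   >>> def mul(a, b):
#   ...     return a * b
#   >>> bufs = pack_apply_message(mul, (6,), {'b' : 7})
#   >>> f, args, kwargs = unpack_apply_message(bufs)
#   >>> f(*args, **kwargs)
#   42
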
#--------------------------------------------------------------------------
# helpers for implementing old MEC API via view.apply
#--------------------------------------------------------------------------

def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns (the
    interactive namespace) as globals() instead of the module globals().
    """
    f.__module__ = '__main__'
    return f

@interactive
def _push(ns):
    """helper method for implementing `client.push` via `client.apply`"""
    globals().update(ns)

@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`"""
    user_ns = globals()
    if isinstance(keys, (list, tuple, set)):
        for key in keys:
            if not user_ns.has_key(key):
                raise NameError("name '%s' is not defined" % key)
        return map(user_ns.get, keys)
    else:
        if not user_ns.has_key(keys):
            raise NameError("name '%s' is not defined" % keys)
        return user_ns.get(keys)

@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    exec code in globals()

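# Usage sketch (illustrative; assumes a running two-engine cluster reachable
# via IPython.parallel.Client). Because of @interactive, these helpers run
# against the engines' interactive namespaces:
#
#   >>> from IPython.parallel import Client
#   >>> view = Client()[:]                   # a view on all engines
#   >>> view.apply_sync(_push, dict(a=5))    # returns [None, ...]
#   >>> view.apply_sync(_pull, 'a')
#   [5, 5]
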
#--------------------------------------------------------------------------
# extra process management utilities
#--------------------------------------------------------------------------

_random_ports = set()

def select_random_ports(n):
    """Select and return n random ports that are available."""
    ports = []
    for i in xrange(n):
        sock = socket.socket()
        sock.bind(('', 0))
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        ports.append(sock)
    for i, sock in enumerate(ports):
        port = sock.getsockname()[1]
        sock.close()
        ports[i] = port
        _random_ports.add(port)
    return ports

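# Example (illustrative sketch): the sockets are only closed once all n have
# been bound, so the returned ports are distinct, and _random_ports keeps
# them from being handed out twice in one process.
#
#   >>> ports = select_random_ports(3)
#   >>> len(ports) == len(set(ports)) == 3
#   True
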
def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup."""
    def terminate_children(sig, frame):
        logging.critical("Got signal %i, terminating children..." % sig)
        for child in children:
            child.terminate()

        sys.exit(sig != SIGINT)
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, terminate_children)

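# Usage sketch (illustrative): after launching subprocesses, install the
# relay so that interrupting the parent also terminates the children:
#
#   from subprocess import Popen
#   children = [Popen(['ipengine']) for i in range(4)]
#   signal_children(children)
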
def generate_exec_key(keyfile):
    """Generate a fresh uuid execution key and write it to keyfile."""
    import uuid
    newkey = str(uuid.uuid4())
    with open(keyfile, 'w') as f:
        f.write(newkey + '\n')
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR | stat.S_IWUSR)

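# Example (illustrative sketch; the path is hypothetical):
#
#   >>> generate_exec_key('/tmp/exec_key')
#   >>> len(open('/tmp/exec_key').read().strip())   # a uuid4 string
#   36
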
def integer_loglevel(loglevel):
    """Coerce a loglevel given as an int or a level name into an integer."""
    try:
        loglevel = int(loglevel)
    except ValueError:
        if isinstance(loglevel, str):
            loglevel = getattr(logging, loglevel)
    return loglevel

def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    logger = logging.getLogger(logname)
    if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(loglevel)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(loglevel)

def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    logger = logging.getLogger()
    if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)

def local_logger(logname, loglevel=logging.DEBUG):
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
        # don't add a second StreamHandler
        return
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)

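# Usage sketch (illustrative): level names are accepted wherever integer
# levels are, via integer_loglevel.
#
#   >>> local_logger('ipcontroller', 'DEBUG')
#   >>> logging.getLogger('ipcontroller').debug('hub started')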