Merge pull request #4497 from minrk/datetime-session...
Matthias Bussonnier
r13542:c1cf1a93 merge
@@ -1,846 +1,848
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hashlib
27 import hashlib
28 import hmac
28 import hmac
29 import logging
29 import logging
30 import os
30 import os
31 import pprint
31 import pprint
32 import random
32 import random
33 import uuid
33 import uuid
34 from datetime import datetime
34 from datetime import datetime
35
35
36 try:
36 try:
37 import cPickle
37 import cPickle
38 pickle = cPickle
38 pickle = cPickle
39 except ImportError:
39 except ImportError:
40 cPickle = None
40 cPickle = None
41 import pickle
41 import pickle
42
42
43 import zmq
43 import zmq
44 from zmq.utils import jsonapi
44 from zmq.utils import jsonapi
45 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.ioloop import IOLoop
46 from zmq.eventloop.zmqstream import ZMQStream
46 from zmq.eventloop.zmqstream import ZMQStream
47
47
48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.config.configurable import Configurable, LoggingConfigurable
49 from IPython.utils import io
49 from IPython.utils import io
50 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 iteritems)
53 iteritems)
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 DottedObjectName, CUnicode, Dict, Integer,
55 DottedObjectName, CUnicode, Dict, Integer,
56 TraitError,
56 TraitError,
57 )
57 )
58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59
59
60 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
61 # utility functions
61 # utility functions
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63
63
64 def squash_unicode(obj):
64 def squash_unicode(obj):
65 """coerce unicode back to bytestrings."""
65 """coerce unicode back to bytestrings."""
66 if isinstance(obj,dict):
66 if isinstance(obj,dict):
67 for key in obj.keys():
67 for key in obj.keys():
68 obj[key] = squash_unicode(obj[key])
68 obj[key] = squash_unicode(obj[key])
69 if isinstance(key, unicode_type):
69 if isinstance(key, unicode_type):
70 obj[squash_unicode(key)] = obj.pop(key)
70 obj[squash_unicode(key)] = obj.pop(key)
71 elif isinstance(obj, list):
71 elif isinstance(obj, list):
72 for i,v in enumerate(obj):
72 for i,v in enumerate(obj):
73 obj[i] = squash_unicode(v)
73 obj[i] = squash_unicode(v)
74 elif isinstance(obj, unicode_type):
74 elif isinstance(obj, unicode_type):
75 obj = obj.encode('utf8')
75 obj = obj.encode('utf8')
76 return obj
76 return obj
77
77
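A minimal sketch of what squash_unicode does, assuming Python 2 semantics where dict.keys() returns a list (the in-place key replacement above relies on that; iterating a py3 keys view while popping would raise):

    # hypothetical input; every unicode key and value is utf8-encoded in place
    msg = {u'key': [u'value', 1], u'nested': {u'a': u'b'}}
    squash_unicode(msg)
    # -> {b'key': [b'value', 1], b'nested': {b'a': b'b'}}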
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # globals and defaults
79 # globals and defaults
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
82 # ISO8601-ify datetime objects
82 # ISO8601-ify datetime objects
83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
84 json_unpacker = lambda s: jsonapi.loads(s)
85
85
86 pickle_packer = lambda o: pickle.dumps(o,-1)
86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
87 pickle_unpacker = pickle.loads
87 pickle_unpacker = pickle.loads
88
88
89 default_packer = json_packer
89 default_packer = json_packer
90 default_unpacker = json_unpacker
90 default_unpacker = json_unpacker
91
91
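The two changed lambdas above are the heart of this pull request: the default JSON unpacker no longer runs extract_dates over every unpacked object, and the pickle packer now squashes datetimes itself. A sketch of the resulting round-trip, using only names imported above:

    from datetime import datetime
    from zmq.utils import jsonapi
    from IPython.utils.jsonutil import date_default, extract_dates

    packed = jsonapi.dumps({'date': datetime.now()}, default=date_default)
    unpacked = jsonapi.loads(packed)    # 'date' comes back as an ISO8601 string
    header = extract_dates(unpacked)    # opt-in conversion back to datetime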
92 DELIM = b"<IDS|MSG>"
92 DELIM = b"<IDS|MSG>"
93 # singleton dummy tracker, which will always report as done
93 # singleton dummy tracker, which will always report as done
94 DONE = zmq.MessageTracker()
94 DONE = zmq.MessageTracker()
95
95
96 #-----------------------------------------------------------------------------
96 #-----------------------------------------------------------------------------
97 # Mixin tools for apps that use Sessions
97 # Mixin tools for apps that use Sessions
98 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
99
99
100 session_aliases = dict(
100 session_aliases = dict(
101 ident = 'Session.session',
101 ident = 'Session.session',
102 user = 'Session.username',
102 user = 'Session.username',
103 keyfile = 'Session.keyfile',
103 keyfile = 'Session.keyfile',
104 )
104 )
105
105
106 session_flags = {
106 session_flags = {
107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 'keyfile' : '' }},
108 'keyfile' : '' }},
109 """Use HMAC digests for authentication of messages.
109 """Use HMAC digests for authentication of messages.
110 Setting this flag will generate a new UUID to use as the HMAC key.
110 Setting this flag will generate a new UUID to use as the HMAC key.
111 """),
111 """),
112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 """Don't authenticate messages."""),
113 """Don't authenticate messages."""),
114 }
114 }
115
115
116 def default_secure(cfg):
116 def default_secure(cfg):
117 """Set the default behavior for a config environment to be secure.
117 """Set the default behavior for a config environment to be secure.
118
118
119 If Session.key/keyfile have not been set, set Session.key to
119 If Session.key/keyfile have not been set, set Session.key to
120 a new random UUID.
120 a new random UUID.
121 """
121 """
122
122
123 if 'Session' in cfg:
123 if 'Session' in cfg:
124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 return
125 return
126 # key/keyfile not specified, generate new UUID:
126 # key/keyfile not specified, generate new UUID:
127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128
128
129
129
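A short usage sketch for default_secure, assuming the Config class from IPython.config.loader (not imported in this module):

    from IPython.config.loader import Config

    cfg = Config()
    default_secure(cfg)
    assert cfg.Session.key     # a fresh UUID was installed as the HMAC key
    default_secure(cfg)        # second call is a no-op: the existing key is kept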
130 #-----------------------------------------------------------------------------
130 #-----------------------------------------------------------------------------
131 # Classes
131 # Classes
132 #-----------------------------------------------------------------------------
132 #-----------------------------------------------------------------------------
133
133
134 class SessionFactory(LoggingConfigurable):
134 class SessionFactory(LoggingConfigurable):
135 """The Base class for configurables that have a Session, Context, logger,
135 """The Base class for configurables that have a Session, Context, logger,
136 and IOLoop.
136 and IOLoop.
137 """
137 """
138
138
139 logname = Unicode('')
139 logname = Unicode('')
140 def _logname_changed(self, name, old, new):
140 def _logname_changed(self, name, old, new):
141 self.log = logging.getLogger(new)
141 self.log = logging.getLogger(new)
142
142
143 # not configurable:
143 # not configurable:
144 context = Instance('zmq.Context')
144 context = Instance('zmq.Context')
145 def _context_default(self):
145 def _context_default(self):
146 return zmq.Context.instance()
146 return zmq.Context.instance()
147
147
148 session = Instance('IPython.kernel.zmq.session.Session')
148 session = Instance('IPython.kernel.zmq.session.Session')
149
149
150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 def _loop_default(self):
151 def _loop_default(self):
152 return IOLoop.instance()
152 return IOLoop.instance()
153
153
154 def __init__(self, **kwargs):
154 def __init__(self, **kwargs):
155 super(SessionFactory, self).__init__(**kwargs)
155 super(SessionFactory, self).__init__(**kwargs)
156
156
157 if self.session is None:
157 if self.session is None:
158 # construct the session
158 # construct the session
159 self.session = Session(**kwargs)
159 self.session = Session(**kwargs)
160
160
161
161
162 class Message(object):
162 class Message(object):
163 """A simple message object that maps dict keys to attributes.
163 """A simple message object that maps dict keys to attributes.
164
164
165 A Message can be created from a dict and a dict from a Message instance
165 A Message can be created from a dict and a dict from a Message instance
166 simply by calling dict(msg_obj)."""
166 simply by calling dict(msg_obj)."""
167
167
168 def __init__(self, msg_dict):
168 def __init__(self, msg_dict):
169 dct = self.__dict__
169 dct = self.__dict__
170 for k, v in iteritems(dict(msg_dict)):
170 for k, v in iteritems(dict(msg_dict)):
171 if isinstance(v, dict):
171 if isinstance(v, dict):
172 v = Message(v)
172 v = Message(v)
173 dct[k] = v
173 dct[k] = v
174
174
175 # Having this iterator lets dict(msg_obj) work out of the box.
175 # Having this iterator lets dict(msg_obj) work out of the box.
176 def __iter__(self):
176 def __iter__(self):
177 return iter(iteritems(self.__dict__))
177 return iter(iteritems(self.__dict__))
178
178
179 def __repr__(self):
179 def __repr__(self):
180 return repr(self.__dict__)
180 return repr(self.__dict__)
181
181
182 def __str__(self):
182 def __str__(self):
183 return pprint.pformat(self.__dict__)
183 return pprint.pformat(self.__dict__)
184
184
185 def __contains__(self, k):
185 def __contains__(self, k):
186 return k in self.__dict__
186 return k in self.__dict__
187
187
188 def __getitem__(self, k):
188 def __getitem__(self, k):
189 return self.__dict__[k]
189 return self.__dict__[k]
190
190
191
191
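A quick illustration of the dict <-> Message round-trip promised in the docstring:

    m = Message({'header': {'msg_id': 'abc'}, 'content': {}})
    m.header.msg_id         # 'abc' -- nested dicts become nested Messages
    'content' in m          # True, via __contains__
    am = dict(m)            # a plain dict again at the top level; nested values
    am['header']['msg_id']  # stay Message objects, but support __getitem__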
192 def msg_header(msg_id, msg_type, username, session):
192 def msg_header(msg_id, msg_type, username, session):
193 date = datetime.now()
193 date = datetime.now()
194 return locals()
194 return locals()
195
195
196 def extract_header(msg_or_header):
196 def extract_header(msg_or_header):
197 """Given a message or header, return the header."""
197 """Given a message or header, return the header."""
198 if not msg_or_header:
198 if not msg_or_header:
199 return {}
199 return {}
200 try:
200 try:
201 # See if msg_or_header is the entire message.
201 # See if msg_or_header is the entire message.
202 h = msg_or_header['header']
202 h = msg_or_header['header']
203 except KeyError:
203 except KeyError:
204 try:
204 try:
205 # See if msg_or_header is just the header
205 # See if msg_or_header is just the header
206 h = msg_or_header['msg_id']
206 h = msg_or_header['msg_id']
207 except KeyError:
207 except KeyError:
208 raise
208 raise
209 else:
209 else:
210 h = msg_or_header
210 h = msg_or_header
211 if not isinstance(h, dict):
211 if not isinstance(h, dict):
212 h = dict(h)
212 h = dict(h)
213 return h
213 return h
214
214
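extract_header accepts either a full message or a bare header; a sketch:

    full = {'header': {'msg_id': '1', 'msg_type': 'execute'}, 'content': {}}
    extract_header(full)             # -> {'msg_id': '1', 'msg_type': 'execute'}
    extract_header(full['header'])   # -> the same dict (already a header)
    extract_header(None)             # -> {}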
215 class Session(Configurable):
215 class Session(Configurable):
216 """Object for handling serialization and sending of messages.
216 """Object for handling serialization and sending of messages.
217
217
218 The Session object handles building messages and sending them
218 The Session object handles building messages and sending them
219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 other over the network via Session objects, and only need to work with the
220 other over the network via Session objects, and only need to work with the
221 dict-based IPython message spec. The Session will handle
221 dict-based IPython message spec. The Session will handle
222 serialization/deserialization, security, and metadata.
222 serialization/deserialization, security, and metadata.
223
223
224 Sessions support configurable serialization via packer/unpacker traits,
224 Sessions support configurable serialization via packer/unpacker traits,
225 and signing with HMAC digests via the key/keyfile traits.
225 and signing with HMAC digests via the key/keyfile traits.
226
226
227 Parameters
227 Parameters
228 ----------
228 ----------
229
229
230 debug : bool
230 debug : bool
231 whether to trigger extra debugging statements
231 whether to trigger extra debugging statements
232 packer/unpacker : str : 'json', 'pickle' or import_string
232 packer/unpacker : str : 'json', 'pickle' or import_string
233 importstrings for methods to serialize message parts. If just
233 importstrings for methods to serialize message parts. If just
234 'json' or 'pickle', predefined JSON and pickle packers will be used.
234 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 Otherwise, the entire importstring must be used.
235 Otherwise, the entire importstring must be used.
236
236
237 The functions must accept at least valid JSON input, and output *bytes*.
237 The functions must accept at least valid JSON input, and output *bytes*.
238
238
239 For example, to use msgpack:
239 For example, to use msgpack:
240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 pack/unpack : callables
241 pack/unpack : callables
242 You can also set the pack/unpack callables for serialization directly.
242 You can also set the pack/unpack callables for serialization directly.
243 session : unicode (must be ascii)
243 session : unicode (must be ascii)
244 the ID of this Session object. The default is to generate a new UUID.
244 the ID of this Session object. The default is to generate a new UUID.
245 username : unicode
245 username : unicode
246 username added to message headers. The default is to ask the OS.
246 username added to message headers. The default is to ask the OS.
247 key : bytes
247 key : bytes
248 The key used to initialize an HMAC signature. If unset, messages
248 The key used to initialize an HMAC signature. If unset, messages
249 will not be signed or checked.
249 will not be signed or checked.
250 keyfile : filepath
250 keyfile : filepath
251 The file containing a key. If this is set, `key` will be initialized
251 The file containing a key. If this is set, `key` will be initialized
252 to the contents of the file.
252 to the contents of the file.
253
253
254 """
254 """
255
255
256 debug=Bool(False, config=True, help="""Debug output in the Session""")
256 debug=Bool(False, config=True, help="""Debug output in the Session""")
257
257
258 packer = DottedObjectName('json',config=True,
258 packer = DottedObjectName('json',config=True,
259 help="""The name of the packer for serializing messages.
259 help="""The name of the packer for serializing messages.
260 Should be one of 'json', 'pickle', or an import name
260 Should be one of 'json', 'pickle', or an import name
261 for a custom callable serializer.""")
261 for a custom callable serializer.""")
262 def _packer_changed(self, name, old, new):
262 def _packer_changed(self, name, old, new):
263 if new.lower() == 'json':
263 if new.lower() == 'json':
264 self.pack = json_packer
264 self.pack = json_packer
265 self.unpack = json_unpacker
265 self.unpack = json_unpacker
266 self.unpacker = new
266 self.unpacker = new
267 elif new.lower() == 'pickle':
267 elif new.lower() == 'pickle':
268 self.pack = pickle_packer
268 self.pack = pickle_packer
269 self.unpack = pickle_unpacker
269 self.unpack = pickle_unpacker
270 self.unpacker = new
270 self.unpacker = new
271 else:
271 else:
272 self.pack = import_item(str(new))
272 self.pack = import_item(str(new))
273
273
274 unpacker = DottedObjectName('json', config=True,
274 unpacker = DottedObjectName('json', config=True,
275 help="""The name of the unpacker for unserializing messages.
275 help="""The name of the unpacker for unserializing messages.
276 Only used with custom functions for `packer`.""")
276 Only used with custom functions for `packer`.""")
277 def _unpacker_changed(self, name, old, new):
277 def _unpacker_changed(self, name, old, new):
278 if new.lower() == 'json':
278 if new.lower() == 'json':
279 self.pack = json_packer
279 self.pack = json_packer
280 self.unpack = json_unpacker
280 self.unpack = json_unpacker
281 self.packer = new
281 self.packer = new
282 elif new.lower() == 'pickle':
282 elif new.lower() == 'pickle':
283 self.pack = pickle_packer
283 self.pack = pickle_packer
284 self.unpack = pickle_unpacker
284 self.unpack = pickle_unpacker
285 self.packer = new
285 self.packer = new
286 else:
286 else:
287 self.unpack = import_item(str(new))
287 self.unpack = import_item(str(new))
288
288
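The msgpack example from the class docstring, as a hedged sketch (assumes the msgpack package is installed; depending on the msgpack version's defaults for str keys and list round-tripping, the startup check in _check_packers may demand a thin wrapper instead of the raw functions):

    s = Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
    s.pack({'a': 1})    # -> msgpack-encoded bytes, as required
    # msgpack has no native datetime support, so _check_packers (below) will
    # wrap pack with squash_dates automatically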
289 session = CUnicode(u'', config=True,
289 session = CUnicode(u'', config=True,
290 help="""The UUID identifying this session.""")
290 help="""The UUID identifying this session.""")
291 def _session_default(self):
291 def _session_default(self):
292 u = unicode_type(uuid.uuid4())
292 u = unicode_type(uuid.uuid4())
293 self.bsession = u.encode('ascii')
293 self.bsession = u.encode('ascii')
294 return u
294 return u
295
295
296 def _session_changed(self, name, old, new):
296 def _session_changed(self, name, old, new):
297 self.bsession = self.session.encode('ascii')
297 self.bsession = self.session.encode('ascii')
298
298
299 # bsession is the session as bytes
299 # bsession is the session as bytes
300 bsession = CBytes(b'')
300 bsession = CBytes(b'')
301
301
302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 help="""Username for the Session. Default is your system username.""",
303 help="""Username for the Session. Default is your system username.""",
304 config=True)
304 config=True)
305
305
306 metadata = Dict({}, config=True,
306 metadata = Dict({}, config=True,
307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308
308
309 # message signature related traits:
309 # message signature related traits:
310
310
311 key = CBytes(b'', config=True,
311 key = CBytes(b'', config=True,
312 help="""execution key, for extra authentication.""")
312 help="""execution key, for extra authentication.""")
313 def _key_changed(self, name, old, new):
313 def _key_changed(self, name, old, new):
314 if new:
314 if new:
315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 else:
316 else:
317 self.auth = None
317 self.auth = None
318
318
319 signature_scheme = Unicode('hmac-sha256', config=True,
319 signature_scheme = Unicode('hmac-sha256', config=True,
320 help="""The digest scheme used to construct the message signatures.
320 help="""The digest scheme used to construct the message signatures.
321 Must have the form 'hmac-HASH'.""")
321 Must have the form 'hmac-HASH'.""")
322 def _signature_scheme_changed(self, name, old, new):
322 def _signature_scheme_changed(self, name, old, new):
323 if not new.startswith('hmac-'):
323 if not new.startswith('hmac-'):
324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 hash_name = new.split('-', 1)[1]
325 hash_name = new.split('-', 1)[1]
326 try:
326 try:
327 self.digest_mod = getattr(hashlib, hash_name)
327 self.digest_mod = getattr(hashlib, hash_name)
328 except AttributeError:
328 except AttributeError:
329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330
330
331 digest_mod = Any()
331 digest_mod = Any()
332 def _digest_mod_default(self):
332 def _digest_mod_default(self):
333 return hashlib.sha256
333 return hashlib.sha256
334
334
335 auth = Instance(hmac.HMAC)
335 auth = Instance(hmac.HMAC)
336
336
337 digest_history = Set()
337 digest_history = Set()
338 digest_history_size = Integer(2**16, config=True,
338 digest_history_size = Integer(2**16, config=True,
339 help="""The maximum number of digests to remember.
339 help="""The maximum number of digests to remember.
340
340
341 The digest history will be culled when it exceeds this value.
341 The digest history will be culled when it exceeds this value.
342 """
342 """
343 )
343 )
344
344
345 keyfile = Unicode('', config=True,
345 keyfile = Unicode('', config=True,
346 help="""path to file containing execution key.""")
346 help="""path to file containing execution key.""")
347 def _keyfile_changed(self, name, old, new):
347 def _keyfile_changed(self, name, old, new):
348 with open(new, 'rb') as f:
348 with open(new, 'rb') as f:
349 self.key = f.read().strip()
349 self.key = f.read().strip()
350
350
351 # for protecting against sends from forks
351 # for protecting against sends from forks
352 pid = Integer()
352 pid = Integer()
353
353
354 # serialization traits:
354 # serialization traits:
355
355
356 pack = Any(default_packer) # the actual packer function
356 pack = Any(default_packer) # the actual packer function
357 def _pack_changed(self, name, old, new):
357 def _pack_changed(self, name, old, new):
358 if not callable(new):
358 if not callable(new):
359 raise TypeError("packer must be callable, not %s"%type(new))
359 raise TypeError("packer must be callable, not %s"%type(new))
360
360
361 unpack = Any(default_unpacker) # the actual unpacker function
361 unpack = Any(default_unpacker) # the actual unpacker function
362 def _unpack_changed(self, name, old, new):
362 def _unpack_changed(self, name, old, new):
363 # unpacker is not checked - it is assumed to be the inverse of pack
363 # unpacker is not checked - it is assumed to be the inverse of pack
364 if not callable(new):
364 if not callable(new):
365 raise TypeError("unpacker must be callable, not %s"%type(new))
365 raise TypeError("unpacker must be callable, not %s"%type(new))
366
366
367 # thresholds:
367 # thresholds:
368 copy_threshold = Integer(2**16, config=True,
368 copy_threshold = Integer(2**16, config=True,
369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 buffer_threshold = Integer(MAX_BYTES, config=True,
370 buffer_threshold = Integer(MAX_BYTES, config=True,
371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 item_threshold = Integer(MAX_ITEMS, config=True,
372 item_threshold = Integer(MAX_ITEMS, config=True,
373 help="""The maximum number of items for a container to be introspected for custom serialization.
373 help="""The maximum number of items for a container to be introspected for custom serialization.
374 Containers larger than this are pickled outright.
374 Containers larger than this are pickled outright.
375 """
375 """
376 )
376 )
377
377
378
378
379 def __init__(self, **kwargs):
379 def __init__(self, **kwargs):
380 """create a Session object
380 """create a Session object
381
381
382 Parameters
382 Parameters
383 ----------
383 ----------
384
384
385 debug : bool
385 debug : bool
386 whether to trigger extra debugging statements
386 whether to trigger extra debugging statements
387 packer/unpacker : str : 'json', 'pickle' or import_string
387 packer/unpacker : str : 'json', 'pickle' or import_string
388 importstrings for methods to serialize message parts. If just
388 importstrings for methods to serialize message parts. If just
389 'json' or 'pickle', predefined JSON and pickle packers will be used.
389 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 Otherwise, the entire importstring must be used.
390 Otherwise, the entire importstring must be used.
391
391
392 The functions must accept at least valid JSON input, and output
392 The functions must accept at least valid JSON input, and output
393 *bytes*.
393 *bytes*.
394
394
395 For example, to use msgpack:
395 For example, to use msgpack:
396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 pack/unpack : callables
397 pack/unpack : callables
398 You can also set the pack/unpack callables for serialization
398 You can also set the pack/unpack callables for serialization
399 directly.
399 directly.
400 session : unicode (must be ascii)
400 session : unicode (must be ascii)
401 the ID of this Session object. The default is to generate a new
401 the ID of this Session object. The default is to generate a new
402 UUID.
402 UUID.
403 bsession : bytes
403 bsession : bytes
404 The session as bytes
404 The session as bytes
405 username : unicode
405 username : unicode
406 username added to message headers. The default is to ask the OS.
406 username added to message headers. The default is to ask the OS.
407 key : bytes
407 key : bytes
408 The key used to initialize an HMAC signature. If unset, messages
408 The key used to initialize an HMAC signature. If unset, messages
409 will not be signed or checked.
409 will not be signed or checked.
410 signature_scheme : str
410 signature_scheme : str
411 The message digest scheme. Currently must be of the form 'hmac-HASH',
411 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 where 'HASH' is a hashing function available in Python's hashlib.
412 where 'HASH' is a hashing function available in Python's hashlib.
413 The default is 'hmac-sha256'.
413 The default is 'hmac-sha256'.
414 This is ignored if 'key' is empty.
414 This is ignored if 'key' is empty.
415 keyfile : filepath
415 keyfile : filepath
416 The file containing a key. If this is set, `key` will be
416 The file containing a key. If this is set, `key` will be
417 initialized to the contents of the file.
417 initialized to the contents of the file.
418 """
418 """
419 super(Session, self).__init__(**kwargs)
419 super(Session, self).__init__(**kwargs)
420 self._check_packers()
420 self._check_packers()
421 self.none = self.pack({})
421 self.none = self.pack({})
422 # accessing self.session triggers _session_default() if needed, so bsession is defined:
422 # accessing self.session triggers _session_default() if needed, so bsession is defined:
423 self.session
423 self.session
424 self.pid = os.getpid()
424 self.pid = os.getpid()
425
425
426 @property
426 @property
427 def msg_id(self):
427 def msg_id(self):
428 """always return new uuid"""
428 """always return new uuid"""
429 return str(uuid.uuid4())
429 return str(uuid.uuid4())
430
430
431 def _check_packers(self):
431 def _check_packers(self):
432 """check packers for binary data and datetime support."""
432 """check packers for datetime support."""
433 pack = self.pack
433 pack = self.pack
434 unpack = self.unpack
434 unpack = self.unpack
435
435
436 # check simple serialization
436 # check simple serialization
437 msg = dict(a=[1,'hi'])
437 msg = dict(a=[1,'hi'])
438 try:
438 try:
439 packed = pack(msg)
439 packed = pack(msg)
440 except Exception as e:
440 except Exception as e:
441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 if self.packer == 'json':
442 if self.packer == 'json':
443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 else:
444 else:
445 jsonmsg = ""
445 jsonmsg = ""
446 raise ValueError(
446 raise ValueError(
447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 )
448 )
449
449
450 # ensure packed message is bytes
450 # ensure packed message is bytes
451 if not isinstance(packed, bytes):
451 if not isinstance(packed, bytes):
452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453
453
454 # check that unpack is pack's inverse
454 # check that unpack is pack's inverse
455 try:
455 try:
456 unpacked = unpack(packed)
456 unpacked = unpack(packed)
457 assert unpacked == msg
457 assert unpacked == msg
458 except Exception as e:
458 except Exception as e:
459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 if self.packer == 'json':
460 if self.packer == 'json':
461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 else:
462 else:
463 jsonmsg = ""
463 jsonmsg = ""
464 raise ValueError(
464 raise ValueError(
465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 )
466 )
467
467
468 # check datetime support
468 # check datetime support
469 msg = dict(t=datetime.now())
469 msg = dict(t=datetime.now())
470 try:
470 try:
471 unpacked = unpack(pack(msg))
471 unpacked = unpack(pack(msg))
472 if isinstance(unpacked['t'], datetime):
473 raise ValueError("Shouldn't deserialize to datetime")
472 except Exception:
474 except Exception:
473 self.pack = lambda o: pack(squash_dates(o))
475 self.pack = lambda o: pack(squash_dates(o))
474 self.unpack = lambda s: extract_dates(unpack(s))
476 self.unpack = lambda s: unpack(s)
475
477
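The fallback installed above encodes the new policy: the packer may be taught to serialize datetimes (via squash_dates), but the unpacker must hand dates back as strings; extract_dates is applied selectively in unserialize instead. A sketch with a plain-JSON packer, using the stdlib json module for illustration:

    import json
    naive_pack = lambda o: json.dumps(o).encode('utf8')   # TypeError on datetime
    naive_unpack = lambda s: json.loads(s)
    # so Session would install:
    pack = lambda o: naive_pack(squash_dates(o))   # datetime -> ISO8601 str first
    unpack = lambda s: naive_unpack(s)             # date strings stay strings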
476 def msg_header(self, msg_type):
478 def msg_header(self, msg_type):
477 return msg_header(self.msg_id, msg_type, self.username, self.session)
479 return msg_header(self.msg_id, msg_type, self.username, self.session)
478
480
479 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
481 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
480 """Return the nested message dict.
482 """Return the nested message dict.
481
483
482 This format is different from what is sent over the wire. The
484 This format is different from what is sent over the wire. The
483 serialize/unserialize methods convert this nested message dict to the wire
485 serialize/unserialize methods convert this nested message dict to the wire
484 format, which is a list of message parts.
486 format, which is a list of message parts.
485 """
487 """
486 msg = {}
488 msg = {}
487 header = self.msg_header(msg_type) if header is None else header
489 header = self.msg_header(msg_type) if header is None else header
488 msg['header'] = header
490 msg['header'] = header
489 msg['msg_id'] = header['msg_id']
491 msg['msg_id'] = header['msg_id']
490 msg['msg_type'] = header['msg_type']
492 msg['msg_type'] = header['msg_type']
491 msg['parent_header'] = {} if parent is None else extract_header(parent)
493 msg['parent_header'] = {} if parent is None else extract_header(parent)
492 msg['content'] = {} if content is None else content
494 msg['content'] = {} if content is None else content
493 msg['metadata'] = self.metadata.copy()
495 msg['metadata'] = self.metadata.copy()
494 if metadata is not None:
496 if metadata is not None:
495 msg['metadata'].update(metadata)
497 msg['metadata'].update(metadata)
496 return msg
498 return msg
497
499
498 def sign(self, msg_list):
500 def sign(self, msg_list):
499 """Sign a message with HMAC digest. If no auth, return b''.
501 """Sign a message with HMAC digest. If no auth, return b''.
500
502
501 Parameters
503 Parameters
502 ----------
504 ----------
503 msg_list : list
505 msg_list : list
504 The [p_header,p_parent,p_content] part of the message list.
506 The [p_header,p_parent,p_content] part of the message list.
505 """
507 """
506 if self.auth is None:
508 if self.auth is None:
507 return b''
509 return b''
508 h = self.auth.copy()
510 h = self.auth.copy()
509 for m in msg_list:
511 for m in msg_list:
510 h.update(m)
512 h.update(m)
511 return str_to_bytes(h.hexdigest())
513 return str_to_bytes(h.hexdigest())
512
514
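Signing is plain keyed HMAC over the serialized frames. An equivalent check written directly against hmac/hashlib (this mirrors sign(), it is not a separate API):

    import hashlib
    import hmac

    key = b'my-key'
    frames = [b'header', b'parent', b'metadata', b'content']
    mac = hmac.HMAC(key, digestmod=hashlib.sha256)
    for frame in frames:
        mac.update(frame)
    signature = mac.hexdigest().encode('ascii')   # what sign() returns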
513 def serialize(self, msg, ident=None):
515 def serialize(self, msg, ident=None):
514 """Serialize the message components to bytes.
516 """Serialize the message components to bytes.
515
517
516 This is roughly the inverse of unserialize. The serialize/unserialize
518 This is roughly the inverse of unserialize. The serialize/unserialize
517 methods work with full message lists, whereas pack/unpack work with
519 methods work with full message lists, whereas pack/unpack work with
518 the individual message parts in the message list.
520 the individual message parts in the message list.
519
521
520 Parameters
522 Parameters
521 ----------
523 ----------
522 msg : dict or Message
524 msg : dict or Message
523 The nested message dict as returned by the self.msg method.
525 The nested message dict as returned by the self.msg method.
524
526
525 Returns
527 Returns
526 -------
528 -------
527 msg_list : list
529 msg_list : list
528 The list of bytes objects to be sent with the format:
530 The list of bytes objects to be sent with the format:
529 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
531 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
530 buffer1,buffer2,...]. In this list, the p_* entities are
532 buffer1,buffer2,...]. In this list, the p_* entities are
531 the packed or serialized versions, so if JSON is used, these
533 the packed or serialized versions, so if JSON is used, these
532 are utf8 encoded JSON strings.
534 are utf8 encoded JSON strings.
533 """
535 """
534 content = msg.get('content', {})
536 content = msg.get('content', {})
535 if content is None:
537 if content is None:
536 content = self.none
538 content = self.none
537 elif isinstance(content, dict):
539 elif isinstance(content, dict):
538 content = self.pack(content)
540 content = self.pack(content)
539 elif isinstance(content, bytes):
541 elif isinstance(content, bytes):
540 # content is already packed, as in a relayed message
542 # content is already packed, as in a relayed message
541 pass
543 pass
542 elif isinstance(content, unicode_type):
544 elif isinstance(content, unicode_type):
543 # should be bytes, but JSON often spits out unicode
545 # should be bytes, but JSON often spits out unicode
544 content = content.encode('utf8')
546 content = content.encode('utf8')
545 else:
547 else:
546 raise TypeError("Content incorrect type: %s"%type(content))
548 raise TypeError("Content incorrect type: %s"%type(content))
547
549
548 real_message = [self.pack(msg['header']),
550 real_message = [self.pack(msg['header']),
549 self.pack(msg['parent_header']),
551 self.pack(msg['parent_header']),
550 self.pack(msg['metadata']),
552 self.pack(msg['metadata']),
551 content,
553 content,
552 ]
554 ]
553
555
554 to_send = []
556 to_send = []
555
557
556 if isinstance(ident, list):
558 if isinstance(ident, list):
557 # accept list of idents
559 # accept list of idents
558 to_send.extend(ident)
560 to_send.extend(ident)
559 elif ident is not None:
561 elif ident is not None:
560 to_send.append(ident)
562 to_send.append(ident)
561 to_send.append(DELIM)
563 to_send.append(DELIM)
562
564
563 signature = self.sign(real_message)
565 signature = self.sign(real_message)
564 to_send.append(signature)
566 to_send.append(signature)
565
567
566 to_send.extend(real_message)
568 to_send.extend(real_message)
567
569
568 return to_send
570 return to_send
569
571
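Tying serialize to the wire format in the docstring; a sketch with a keyed Session:

    session = Session(key=b'secret')
    msg = session.msg('execute', content={'code': 'pass'})
    frames = session.serialize(msg, ident=b'client-1')
    # frames == [b'client-1', b'<IDS|MSG>', <hmac hexdigest>,
    #            <p_header>, <p_parent>, <p_metadata>, <p_content>]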
570 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
572 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
571 buffers=None, track=False, header=None, metadata=None):
573 buffers=None, track=False, header=None, metadata=None):
572 """Build and send a message via stream or socket.
574 """Build and send a message via stream or socket.
573
575
574 The message format used by this function internally is as follows:
576 The message format used by this function internally is as follows:
575
577
576 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
578 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
577 buffer1,buffer2,...]
579 buffer1,buffer2,...]
578
580
579 The serialize/unserialize methods convert the nested message dict into this
581 The serialize/unserialize methods convert the nested message dict into this
580 format.
582 format.
581
583
582 Parameters
584 Parameters
583 ----------
585 ----------
584
586
585 stream : zmq.Socket or ZMQStream
587 stream : zmq.Socket or ZMQStream
586 The socket-like object used to send the data.
588 The socket-like object used to send the data.
587 msg_or_type : str or Message/dict
589 msg_or_type : str or Message/dict
588 Normally, msg_or_type will be a msg_type unless a message is being
590 Normally, msg_or_type will be a msg_type unless a message is being
589 sent more than once. If a header is supplied, this can be set to
591 sent more than once. If a header is supplied, this can be set to
590 None and the msg_type will be pulled from the header.
592 None and the msg_type will be pulled from the header.
591
593
592 content : dict or None
594 content : dict or None
593 The content of the message (ignored if msg_or_type is a message).
595 The content of the message (ignored if msg_or_type is a message).
594 header : dict or None
596 header : dict or None
595 The header dict for the message (ignored if msg_or_type is a message).
597 The header dict for the message (ignored if msg_or_type is a message).
596 parent : Message or dict or None
598 parent : Message or dict or None
597 The parent or parent header describing the parent of this message
599 The parent or parent header describing the parent of this message
598 (ignored if msg_or_type is a message).
600 (ignored if msg_or_type is a message).
599 ident : bytes or list of bytes
601 ident : bytes or list of bytes
600 The zmq.IDENTITY routing path.
602 The zmq.IDENTITY routing path.
601 metadata : dict or None
603 metadata : dict or None
602 The metadata describing the message
604 The metadata describing the message
603 buffers : list or None
605 buffers : list or None
604 The already-serialized buffers to be appended to the message.
606 The already-serialized buffers to be appended to the message.
605 track : bool
607 track : bool
606 Whether to track. Only for use with Sockets, because ZMQStream
608 Whether to track. Only for use with Sockets, because ZMQStream
607 objects cannot track messages.
609 objects cannot track messages.
608
610
609
611
610 Returns
612 Returns
611 -------
613 -------
612 msg : dict
614 msg : dict
613 The constructed message.
615 The constructed message.
614 """
616 """
615 if not isinstance(stream, zmq.Socket):
617 if not isinstance(stream, zmq.Socket):
616 # ZMQStreams and dummy sockets do not support tracking.
618 # ZMQStreams and dummy sockets do not support tracking.
617 track = False
619 track = False
618
620
619 if isinstance(msg_or_type, (Message, dict)):
621 if isinstance(msg_or_type, (Message, dict)):
620 # We got a Message or message dict, not a msg_type so don't
622 # We got a Message or message dict, not a msg_type so don't
621 # build a new Message.
623 # build a new Message.
622 msg = msg_or_type
624 msg = msg_or_type
623 else:
625 else:
624 msg = self.msg(msg_or_type, content=content, parent=parent,
626 msg = self.msg(msg_or_type, content=content, parent=parent,
625 header=header, metadata=metadata)
627 header=header, metadata=metadata)
626 if not os.getpid() == self.pid:
628 if not os.getpid() == self.pid:
627 io.rprint("WARNING: attempted to send message from fork")
629 io.rprint("WARNING: attempted to send message from fork")
628 io.rprint(msg)
630 io.rprint(msg)
629 return
631 return
630 buffers = [] if buffers is None else buffers
632 buffers = [] if buffers is None else buffers
631 to_send = self.serialize(msg, ident)
633 to_send = self.serialize(msg, ident)
632 to_send.extend(buffers)
634 to_send.extend(buffers)
633 longest = max([ len(s) for s in to_send ])
635 longest = max([ len(s) for s in to_send ])
634 copy = (longest < self.copy_threshold)
636 copy = (longest < self.copy_threshold)
635
637
636 if buffers and track and not copy:
638 if buffers and track and not copy:
637 # only really track when we are doing zero-copy buffers
639 # only really track when we are doing zero-copy buffers
638 tracker = stream.send_multipart(to_send, copy=False, track=True)
640 tracker = stream.send_multipart(to_send, copy=False, track=True)
639 else:
641 else:
640 # use dummy tracker, which will be done immediately
642 # use dummy tracker, which will be done immediately
641 tracker = DONE
643 tracker = DONE
642 stream.send_multipart(to_send, copy=copy)
644 stream.send_multipart(to_send, copy=copy)
643
645
644 if self.debug:
646 if self.debug:
645 pprint.pprint(msg)
647 pprint.pprint(msg)
646 pprint.pprint(to_send)
648 pprint.pprint(to_send)
647 pprint.pprint(buffers)
649 pprint.pprint(buffers)
648
650
649 msg['tracker'] = tracker
651 msg['tracker'] = tracker
650
652
651 return msg
653 return msg
652
654
653 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
655 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
654 """Send a raw message via ident path.
656 """Send a raw message via ident path.
655
657
656 This method is used to send an already-serialized message.
658 This method is used to send an already-serialized message.
657
659
658 Parameters
660 Parameters
659 ----------
661 ----------
660 stream : ZMQStream or Socket
662 stream : ZMQStream or Socket
661 The ZMQ stream or socket to use for sending the message.
663 The ZMQ stream or socket to use for sending the message.
662 msg_list : list
664 msg_list : list
663 The serialized list of messages to send. This only includes the
665 The serialized list of messages to send. This only includes the
664 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
666 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
665 the message.
667 the message.
666 ident : ident or list
668 ident : ident or list
667 A single ident or a list of idents to use in sending.
669 A single ident or a list of idents to use in sending.
668 """
670 """
669 to_send = []
671 to_send = []
670 if isinstance(ident, bytes):
672 if isinstance(ident, bytes):
671 ident = [ident]
673 ident = [ident]
672 if ident is not None:
674 if ident is not None:
673 to_send.extend(ident)
675 to_send.extend(ident)
674
676
675 to_send.append(DELIM)
677 to_send.append(DELIM)
676 to_send.append(self.sign(msg_list))
678 to_send.append(self.sign(msg_list))
677 to_send.extend(msg_list)
679 to_send.extend(msg_list)
678 stream.send_multipart(to_send, flags, copy=copy)
680 stream.send_multipart(to_send, flags, copy=copy)
679
681
680 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
682 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
681 """Receive and unpack a message.
683 """Receive and unpack a message.
682
684
683 Parameters
685 Parameters
684 ----------
686 ----------
685 socket : ZMQStream or Socket
687 socket : ZMQStream or Socket
686 The socket or stream to use in receiving.
688 The socket or stream to use in receiving.
687
689
688 Returns
690 Returns
689 -------
691 -------
690 [idents], msg
692 [idents], msg
691 [idents] is a list of idents and msg is a nested message dict of
693 [idents] is a list of idents and msg is a nested message dict of
692 same format as self.msg returns.
694 same format as self.msg returns.
693 """
695 """
694 if isinstance(socket, ZMQStream):
696 if isinstance(socket, ZMQStream):
695 socket = socket.socket
697 socket = socket.socket
696 try:
698 try:
697 msg_list = socket.recv_multipart(mode, copy=copy)
699 msg_list = socket.recv_multipart(mode, copy=copy)
698 except zmq.ZMQError as e:
700 except zmq.ZMQError as e:
699 if e.errno == zmq.EAGAIN:
701 if e.errno == zmq.EAGAIN:
700 # We can convert EAGAIN to None as we know in this case
702 # We can convert EAGAIN to None as we know in this case
701 # recv_multipart won't return None.
703 # recv_multipart won't return None.
702 return None,None
704 return None,None
703 else:
705 else:
704 raise
706 raise
705 # split multipart message into identity list and message dict
707 # split multipart message into identity list and message dict
706 # invalid large messages can cause very expensive string comparisons
708 # invalid large messages can cause very expensive string comparisons
707 idents, msg_list = self.feed_identities(msg_list, copy)
709 idents, msg_list = self.feed_identities(msg_list, copy)
708 try:
710 try:
709 return idents, self.unserialize(msg_list, content=content, copy=copy)
711 return idents, self.unserialize(msg_list, content=content, copy=copy)
710 except Exception as e:
712 except Exception as e:
711 # TODO: handle it
713 # TODO: handle it
712 raise e
714 raise e
713
715
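An end-to-end send/recv sketch over a PAIR socket pair (the test file below does the same dance with sockets from BaseZMQTestCase):

    import zmq
    ctx = zmq.Context.instance()
    a = ctx.socket(zmq.PAIR)
    b = ctx.socket(zmq.PAIR)
    a.bind('inproc://session-demo')
    b.connect('inproc://session-demo')

    session = Session()
    session.send(a, 'execute', content={'code': 'pass'})
    idents, msg = session.recv(b, mode=0)    # mode=0 blocks instead of NOBLOCK
    assert msg['msg_type'] == 'execute'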
714 def feed_identities(self, msg_list, copy=True):
716 def feed_identities(self, msg_list, copy=True):
715 """Split the identities from the rest of the message.
717 """Split the identities from the rest of the message.
716
718
717 Feed until DELIM is reached, then return the prefix as idents and
719 Feed until DELIM is reached, then return the prefix as idents and
718 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
720 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
719 but that would be silly.
721 but that would be silly.
720
722
721 Parameters
723 Parameters
722 ----------
724 ----------
723 msg_list : a list of Message or bytes objects
725 msg_list : a list of Message or bytes objects
724 The message to be split.
726 The message to be split.
725 copy : bool
727 copy : bool
726 flag determining whether the arguments are bytes or Messages
728 flag determining whether the arguments are bytes or Messages
727
729
728 Returns
730 Returns
729 -------
731 -------
730 (idents, msg_list) : two lists
732 (idents, msg_list) : two lists
731 idents will always be a list of bytes, each of which is a ZMQ
733 idents will always be a list of bytes, each of which is a ZMQ
732 identity. msg_list will be a list of bytes or zmq.Messages of the
734 identity. msg_list will be a list of bytes or zmq.Messages of the
733 form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and
735 form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and
734 should be unpackable/unserializable via self.unserialize at this
736 should be unpackable/unserializable via self.unserialize at this
735 point.
737 point.
736 """
738 """
737 if copy:
739 if copy:
738 idx = msg_list.index(DELIM)
740 idx = msg_list.index(DELIM)
739 return msg_list[:idx], msg_list[idx+1:]
741 return msg_list[:idx], msg_list[idx+1:]
740 else:
742 else:
741 failed = True
743 failed = True
742 for idx,m in enumerate(msg_list):
744 for idx,m in enumerate(msg_list):
743 if m.bytes == DELIM:
745 if m.bytes == DELIM:
744 failed = False
746 failed = False
745 break
747 break
746 if failed:
748 if failed:
747 raise ValueError("DELIM not in msg_list")
749 raise ValueError("DELIM not in msg_list")
748 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
750 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
749 return [m.bytes for m in idents], msg_list
751 return [m.bytes for m in idents], msg_list
750
752
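What recv does internally with feed_identities, as a standalone sketch:

    session = Session()
    frames = session.serialize(session.msg('execute'), ident=b'client-1')
    idents, msg_list = session.feed_identities(frames)
    # idents == [b'client-1']; msg_list[0] is the HMAC signature frame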
751 def _add_digest(self, signature):
753 def _add_digest(self, signature):
752 """add a digest to history to protect against replay attacks"""
754 """add a digest to history to protect against replay attacks"""
753 if self.digest_history_size == 0:
755 if self.digest_history_size == 0:
754 # no history, never add digests
756 # no history, never add digests
755 return
757 return
756
758
757 self.digest_history.add(signature)
759 self.digest_history.add(signature)
758 if len(self.digest_history) > self.digest_history_size:
760 if len(self.digest_history) > self.digest_history_size:
759 # threshold reached, cull 10%
761 # threshold reached, cull 10%
760 self._cull_digest_history()
762 self._cull_digest_history()
761
763
762 def _cull_digest_history(self):
764 def _cull_digest_history(self):
763 """cull the digest history
765 """cull the digest history
764
766
765 Removes a randomly selected 10% of the digest history
767 Removes a randomly selected 10% of the digest history
766 """
768 """
767 current = len(self.digest_history)
769 current = len(self.digest_history)
768 n_to_cull = max(int(current // 10), current - self.digest_history_size)
770 n_to_cull = max(int(current // 10), current - self.digest_history_size)
769 if n_to_cull >= current:
771 if n_to_cull >= current:
770 self.digest_history = set()
772 self.digest_history = set()
771 return
773 return
772 to_cull = random.sample(self.digest_history, n_to_cull)
774 to_cull = random.sample(self.digest_history, n_to_cull)
773 self.digest_history.difference_update(to_cull)
775 self.digest_history.difference_update(to_cull)
774
776
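The culling arithmetic, worked through with the default history size:

    size = 2**16                  # digest_history_size default
    current = size + 1            # history just crossed the threshold
    n_to_cull = max(current // 10, current - size)   # max(6553, 1) == 6553
    # so roughly 10% of remembered signatures are dropped at random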
775 def unserialize(self, msg_list, content=True, copy=True):
777 def unserialize(self, msg_list, content=True, copy=True):
776 """Unserialize a msg_list to a nested message dict.
778 """Unserialize a msg_list to a nested message dict.
777
779
778 This is roughly the inverse of serialize. The serialize/unserialize
780 This is roughly the inverse of serialize. The serialize/unserialize
779 methods work with full message lists, whereas pack/unpack work with
781 methods work with full message lists, whereas pack/unpack work with
780 the individual message parts in the message list.
782 the individual message parts in the message list.
781
783
782 Parameters
784 Parameters
783 ----------
785 ----------
784 msg_list : list of bytes or Message objects
786 msg_list : list of bytes or Message objects
785 The list of message parts of the form [HMAC,p_header,p_parent,
787 The list of message parts of the form [HMAC,p_header,p_parent,
786 p_metadata,p_content,buffer1,buffer2,...].
788 p_metadata,p_content,buffer1,buffer2,...].
787 content : bool (True)
789 content : bool (True)
788 Whether to unpack the content dict (True), or leave it packed
790 Whether to unpack the content dict (True), or leave it packed
789 (False).
791 (False).
790 copy : bool (True)
792 copy : bool (True)
791 Whether to return the bytes (True), or the non-copying Message
793 Whether to return the bytes (True), or the non-copying Message
792 object in each place (False).
794 object in each place (False).
793
795
794 Returns
796 Returns
795 -------
797 -------
796 msg : dict
798 msg : dict
797 The nested message dict with top-level keys [header, parent_header,
799 The nested message dict with top-level keys [header, parent_header,
798 content, buffers].
800 content, buffers].
799 """
801 """
800 minlen = 5
802 minlen = 5
801 message = {}
803 message = {}
802 if not copy:
804 if not copy:
803 for i in range(minlen):
805 for i in range(minlen):
804 msg_list[i] = msg_list[i].bytes
806 msg_list[i] = msg_list[i].bytes
805 if self.auth is not None:
807 if self.auth is not None:
806 signature = msg_list[0]
808 signature = msg_list[0]
807 if not signature:
809 if not signature:
808 raise ValueError("Unsigned Message")
810 raise ValueError("Unsigned Message")
809 if signature in self.digest_history:
811 if signature in self.digest_history:
810 raise ValueError("Duplicate Signature: %r" % signature)
812 raise ValueError("Duplicate Signature: %r" % signature)
811 self._add_digest(signature)
813 self._add_digest(signature)
812 check = self.sign(msg_list[1:5])
814 check = self.sign(msg_list[1:5])
813 if not signature == check:
815 if not signature == check:
814 raise ValueError("Invalid Signature: %r" % signature)
816 raise ValueError("Invalid Signature: %r" % signature)
815 if not len(msg_list) >= minlen:
817 if not len(msg_list) >= minlen:
816 raise TypeError("malformed message, must have at least %i elements"%minlen)
818 raise TypeError("malformed message, must have at least %i elements"%minlen)
817 header = self.unpack(msg_list[1])
819 header = self.unpack(msg_list[1])
818 message['header'] = header
820 message['header'] = extract_dates(header)
819 message['msg_id'] = header['msg_id']
821 message['msg_id'] = header['msg_id']
820 message['msg_type'] = header['msg_type']
822 message['msg_type'] = header['msg_type']
821 message['parent_header'] = self.unpack(msg_list[2])
823 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
822 message['metadata'] = self.unpack(msg_list[3])
824 message['metadata'] = self.unpack(msg_list[3])
823 if content:
825 if content:
824 message['content'] = self.unpack(msg_list[4])
826 message['content'] = self.unpack(msg_list[4])
825 else:
827 else:
826 message['content'] = msg_list[4]
828 message['content'] = msg_list[4]
827
829
828 message['buffers'] = msg_list[5:]
830 message['buffers'] = msg_list[5:]
829 return message
831 return message
830
832
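The two changed lines above are the receive-side half of this PR: date strings are revived into datetime objects only in header and parent_header, while metadata and content are returned exactly as the unpacker produced them. The visible effect, sketched:

    session = Session()
    msg = session.msg('execute')
    idents, frames = session.feed_identities(session.serialize(msg))
    new_msg = session.unserialize(frames)
    isinstance(new_msg['header']['date'], datetime)   # True, via extract_dates
    # date-like strings inside new_msg['content'] are left untouched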
831 def test_msg2obj():
833 def test_msg2obj():
832 am = dict(x=1)
834 am = dict(x=1)
833 ao = Message(am)
835 ao = Message(am)
834 assert ao.x == am['x']
836 assert ao.x == am['x']
835
837
836 am['y'] = dict(z=1)
838 am['y'] = dict(z=1)
837 ao = Message(am)
839 ao = Message(am)
838 assert ao.y.z == am['y']['z']
840 assert ao.y.z == am['y']['z']
839
841
840 k1, k2 = 'y', 'z'
842 k1, k2 = 'y', 'z'
841 assert ao[k1][k2] == am[k1][k2]
843 assert ao[k1][k2] == am[k1][k2]
842
844
843 am2 = dict(ao)
845 am2 = dict(ao)
844 assert am['x'] == am2['x']
846 assert am['x'] == am2['x']
845 assert am['y']['z'] == am2['y']['z']
847 assert am['y']['z'] == am2['y']['z']
846
848
@@ -1,271 +1,289
1 """test building messages with streamsession"""
1 """test building messages with streamsession"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import uuid
15 import uuid
16 from datetime import datetime
17
16 import zmq
18 import zmq
17
19
18 from zmq.tests import BaseZMQTestCase
20 from zmq.tests import BaseZMQTestCase
19 from zmq.eventloop.zmqstream import ZMQStream
21 from zmq.eventloop.zmqstream import ZMQStream
20
22
21 from IPython.kernel.zmq import session as ss
23 from IPython.kernel.zmq import session as ss
22
24
25 from IPython.testing.decorators import skipif, module_not_available
26 from IPython.utils.py3compat import string_types
27 from IPython.utils import jsonutil
28
23 def _bad_packer(obj):
29 def _bad_packer(obj):
24 raise TypeError("I don't work")
30 raise TypeError("I don't work")
25
31
26 def _bad_unpacker(bytes):
32 def _bad_unpacker(bytes):
27 raise TypeError("I don't work either")
33 raise TypeError("I don't work either")
28
34
29 class SessionTestCase(BaseZMQTestCase):
35 class SessionTestCase(BaseZMQTestCase):
30
36
31 def setUp(self):
37 def setUp(self):
32 BaseZMQTestCase.setUp(self)
38 BaseZMQTestCase.setUp(self)
33 self.session = ss.Session()
39 self.session = ss.Session()
34
40
35
41
36 class TestSession(SessionTestCase):
42 class TestSession(SessionTestCase):
37
43
38 def test_msg(self):
44 def test_msg(self):
39 """message format"""
45 """message format"""
40 msg = self.session.msg('execute')
46 msg = self.session.msg('execute')
41 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
47 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
42 s = set(msg.keys())
48 s = set(msg.keys())
43 self.assertEqual(s, thekeys)
49 self.assertEqual(s, thekeys)
44 self.assertTrue(isinstance(msg['content'],dict))
50 self.assertTrue(isinstance(msg['content'],dict))
45 self.assertTrue(isinstance(msg['metadata'],dict))
51 self.assertTrue(isinstance(msg['metadata'],dict))
46 self.assertTrue(isinstance(msg['header'],dict))
52 self.assertTrue(isinstance(msg['header'],dict))
47 self.assertTrue(isinstance(msg['parent_header'],dict))
53 self.assertTrue(isinstance(msg['parent_header'],dict))
48 self.assertTrue(isinstance(msg['msg_id'],str))
54 self.assertTrue(isinstance(msg['msg_id'],str))
49 self.assertTrue(isinstance(msg['msg_type'],str))
55 self.assertTrue(isinstance(msg['msg_type'],str))
50 self.assertEqual(msg['header']['msg_type'], 'execute')
56 self.assertEqual(msg['header']['msg_type'], 'execute')
51 self.assertEqual(msg['msg_type'], 'execute')
57 self.assertEqual(msg['msg_type'], 'execute')
52
58
53 def test_serialize(self):
59 def test_serialize(self):
54 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
60 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
55 msg_list = self.session.serialize(msg, ident=b'foo')
61 msg_list = self.session.serialize(msg, ident=b'foo')
56 ident, msg_list = self.session.feed_identities(msg_list)
62 ident, msg_list = self.session.feed_identities(msg_list)
57 new_msg = self.session.unserialize(msg_list)
63 new_msg = self.session.unserialize(msg_list)
58 self.assertEqual(ident[0], b'foo')
64 self.assertEqual(ident[0], b'foo')
59 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
65 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
60 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
66 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
61 self.assertEqual(new_msg['header'],msg['header'])
67 self.assertEqual(new_msg['header'],msg['header'])
62 self.assertEqual(new_msg['content'],msg['content'])
68 self.assertEqual(new_msg['content'],msg['content'])
63 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
69 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
64 self.assertEqual(new_msg['metadata'],msg['metadata'])
70 self.assertEqual(new_msg['metadata'],msg['metadata'])
65 # ensure floats don't come out as Decimal:
71 # ensure floats don't come out as Decimal:
66 self.assertEqual(type(msg['content']['b']),type(new_msg['content']['b']))
72 self.assertEqual(type(msg['content']['b']),type(new_msg['content']['b']))
67
73
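The float/Decimal assertion above guards against JSON backends that decode numbers as decimal.Decimal instead of float; a one-line sketch of the property being checked, assuming the standard-library json module:

    import json
    assert type(json.loads('{"b": 1.1}')['b']) is float  # not decimal.Decimal
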
68 def test_send(self):
74 def test_send(self):
69 ctx = zmq.Context.instance()
75 ctx = zmq.Context.instance()
70 A = ctx.socket(zmq.PAIR)
76 A = ctx.socket(zmq.PAIR)
71 B = ctx.socket(zmq.PAIR)
77 B = ctx.socket(zmq.PAIR)
72 A.bind("inproc://test")
78 A.bind("inproc://test")
73 B.connect("inproc://test")
79 B.connect("inproc://test")
74
80
75 msg = self.session.msg('execute', content=dict(a=10))
81 msg = self.session.msg('execute', content=dict(a=10))
76 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
82 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
77
83
78 ident, msg_list = self.session.feed_identities(B.recv_multipart())
84 ident, msg_list = self.session.feed_identities(B.recv_multipart())
79 new_msg = self.session.unserialize(msg_list)
85 new_msg = self.session.unserialize(msg_list)
80 self.assertEqual(ident[0], b'foo')
86 self.assertEqual(ident[0], b'foo')
81 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
87 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
82 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
88 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
83 self.assertEqual(new_msg['header'],msg['header'])
89 self.assertEqual(new_msg['header'],msg['header'])
84 self.assertEqual(new_msg['content'],msg['content'])
90 self.assertEqual(new_msg['content'],msg['content'])
85 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
91 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
86 self.assertEqual(new_msg['metadata'],msg['metadata'])
92 self.assertEqual(new_msg['metadata'],msg['metadata'])
87 self.assertEqual(new_msg['buffers'],[b'bar'])
93 self.assertEqual(new_msg['buffers'],[b'bar'])
88
94
89 content = msg['content']
95 content = msg['content']
90 header = msg['header']
96 header = msg['header']
91 parent = msg['parent_header']
97 parent = msg['parent_header']
92 metadata = msg['metadata']
98 metadata = msg['metadata']
93 msg_type = header['msg_type']
99 msg_type = header['msg_type']
94 self.session.send(A, None, content=content, parent=parent,
100 self.session.send(A, None, content=content, parent=parent,
95 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
101 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
96 ident, msg_list = self.session.feed_identities(B.recv_multipart())
102 ident, msg_list = self.session.feed_identities(B.recv_multipart())
97 new_msg = self.session.unserialize(msg_list)
103 new_msg = self.session.unserialize(msg_list)
98 self.assertEqual(ident[0], b'foo')
104 self.assertEqual(ident[0], b'foo')
99 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
105 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
100 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
106 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
101 self.assertEqual(new_msg['header'],msg['header'])
107 self.assertEqual(new_msg['header'],msg['header'])
102 self.assertEqual(new_msg['content'],msg['content'])
108 self.assertEqual(new_msg['content'],msg['content'])
103 self.assertEqual(new_msg['metadata'],msg['metadata'])
109 self.assertEqual(new_msg['metadata'],msg['metadata'])
104 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
110 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
105 self.assertEqual(new_msg['buffers'],[b'bar'])
111 self.assertEqual(new_msg['buffers'],[b'bar'])
106
112
107 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
108 ident, new_msg = self.session.recv(B)
114 ident, new_msg = self.session.recv(B)
109 self.assertEqual(ident[0], b'foo')
115 self.assertEqual(ident[0], b'foo')
110 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
111 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
112 self.assertEqual(new_msg['header'],msg['header'])
118 self.assertEqual(new_msg['header'],msg['header'])
113 self.assertEqual(new_msg['content'],msg['content'])
119 self.assertEqual(new_msg['content'],msg['content'])
114 self.assertEqual(new_msg['metadata'],msg['metadata'])
120 self.assertEqual(new_msg['metadata'],msg['metadata'])
115 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
116 self.assertEqual(new_msg['buffers'],[b'bar'])
122 self.assertEqual(new_msg['buffers'],[b'bar'])
117
123
118 A.close()
124 A.close()
119 B.close()
125 B.close()
120 ctx.term()
126 ctx.term()
121
127
122 def test_args(self):
128 def test_args(self):
123 """initialization arguments for Session"""
129 """initialization arguments for Session"""
124 s = self.session
130 s = self.session
125 self.assertTrue(s.pack is ss.default_packer)
131 self.assertTrue(s.pack is ss.default_packer)
126 self.assertTrue(s.unpack is ss.default_unpacker)
132 self.assertTrue(s.unpack is ss.default_unpacker)
127 self.assertEqual(s.username, os.environ.get('USER', u'username'))
133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
128
134
129 s = ss.Session()
135 s = ss.Session()
130 self.assertEqual(s.username, os.environ.get('USER', u'username'))
136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
131
137
132 self.assertRaises(TypeError, ss.Session, pack='hi')
138 self.assertRaises(TypeError, ss.Session, pack='hi')
133 self.assertRaises(TypeError, ss.Session, unpack='hi')
139 self.assertRaises(TypeError, ss.Session, unpack='hi')
134 u = str(uuid.uuid4())
140 u = str(uuid.uuid4())
135 s = ss.Session(username=u'carrot', session=u)
141 s = ss.Session(username=u'carrot', session=u)
136 self.assertEqual(s.session, u)
142 self.assertEqual(s.session, u)
137 self.assertEqual(s.username, u'carrot')
143 self.assertEqual(s.username, u'carrot')
138
144
139 def test_tracking(self):
145 def test_tracking(self):
140 """test tracking messages"""
146 """test tracking messages"""
141 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
142 s = self.session
148 s = self.session
143 s.copy_threshold = 1
149 s.copy_threshold = 1
144 stream = ZMQStream(a)
150 stream = ZMQStream(a)
145 msg = s.send(a, 'hello', track=False)
151 msg = s.send(a, 'hello', track=False)
146 self.assertTrue(msg['tracker'] is ss.DONE)
152 self.assertTrue(msg['tracker'] is ss.DONE)
147 msg = s.send(a, 'hello', track=True)
153 msg = s.send(a, 'hello', track=True)
148 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
149 M = zmq.Message(b'hi there', track=True)
155 M = zmq.Message(b'hi there', track=True)
150 msg = s.send(a, 'hello', buffers=[M], track=True)
156 msg = s.send(a, 'hello', buffers=[M], track=True)
151 t = msg['tracker']
157 t = msg['tracker']
152 self.assertTrue(isinstance(t, zmq.MessageTracker))
158 self.assertTrue(isinstance(t, zmq.MessageTracker))
153 self.assertRaises(zmq.NotDone, t.wait, .1)
159 self.assertRaises(zmq.NotDone, t.wait, .1)
154 del M
160 del M
155 t.wait(1) # this will raise
161 t.wait(1) # this will raise
156
162
157
163
158 # def test_rekey(self):
159 # """rekeying dict around json str keys"""
160 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
161 # self.assertRaises(KeyError, ss.rekey, d)
162 #
163 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
164 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
165 # rd = ss.rekey(d)
166 # self.assertEqual(d2,rd)
167 #
168 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
169 # d2 = {1.5:d['1.5'],1:d['1']}
170 # rd = ss.rekey(d)
171 # self.assertEqual(d2,rd)
172 #
173 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
174 # self.assertRaises(KeyError, ss.rekey, d)
175 #
176 def test_unique_msg_ids(self):
164 def test_unique_msg_ids(self):
177 """test that messages receive unique ids"""
165 """test that messages receive unique ids"""
178 ids = set()
166 ids = set()
179 for i in range(2**12):
167 for i in range(2**12):
180 h = self.session.msg_header('test')
168 h = self.session.msg_header('test')
181 msg_id = h['msg_id']
169 msg_id = h['msg_id']
182 self.assertTrue(msg_id not in ids)
170 self.assertTrue(msg_id not in ids)
183 ids.add(msg_id)
171 ids.add(msg_id)
184
172
185 def test_feed_identities(self):
173 def test_feed_identities(self):
186 """scrub the front for zmq IDENTITIES"""
174 """scrub the front for zmq IDENTITIES"""
187 theids = "engine client other".split()
175 theids = "engine client other".split()
188 content = dict(code='whoda',stuff=object())
176 content = dict(code='whoda',stuff=object())
189 themsg = self.session.msg('execute',content=content)
177 themsg = self.session.msg('execute',content=content)
190 pmsg = theids
178 pmsg = theids
191
179
192 def test_session_id(self):
180 def test_session_id(self):
193 session = ss.Session()
181 session = ss.Session()
194 # get bs before us
182 # get bs before us
195 bs = session.bsession
183 bs = session.bsession
196 us = session.session
184 us = session.session
197 self.assertEqual(us.encode('ascii'), bs)
185 self.assertEqual(us.encode('ascii'), bs)
198 session = ss.Session()
186 session = ss.Session()
199 # get us before bs
187 # get us before bs
200 us = session.session
188 us = session.session
201 bs = session.bsession
189 bs = session.bsession
202 self.assertEqual(us.encode('ascii'), bs)
190 self.assertEqual(us.encode('ascii'), bs)
203 # change propagates:
191 # change propagates:
204 session.session = 'something else'
192 session.session = 'something else'
205 bs = session.bsession
193 bs = session.bsession
206 us = session.session
194 us = session.session
207 self.assertEqual(us.encode('ascii'), bs)
195 self.assertEqual(us.encode('ascii'), bs)
208 session = ss.Session(session='stuff')
196 session = ss.Session(session='stuff')
209 # get us before bs
197 # get us before bs
210 self.assertEqual(session.bsession, session.session.encode('ascii'))
198 self.assertEqual(session.bsession, session.session.encode('ascii'))
211 self.assertEqual(b'stuff', session.bsession)
199 self.assertEqual(b'stuff', session.bsession)
212
200
213 def test_zero_digest_history(self):
201 def test_zero_digest_history(self):
214 session = ss.Session(digest_history_size=0)
202 session = ss.Session(digest_history_size=0)
215 for i in range(11):
203 for i in range(11):
216 session._add_digest(uuid.uuid4().bytes)
204 session._add_digest(uuid.uuid4().bytes)
217 self.assertEqual(len(session.digest_history), 0)
205 self.assertEqual(len(session.digest_history), 0)
218
206
219 def test_cull_digest_history(self):
207 def test_cull_digest_history(self):
220 session = ss.Session(digest_history_size=100)
208 session = ss.Session(digest_history_size=100)
221 for i in range(100):
209 for i in range(100):
222 session._add_digest(uuid.uuid4().bytes)
210 session._add_digest(uuid.uuid4().bytes)
223 self.assertTrue(len(session.digest_history) == 100)
211 self.assertTrue(len(session.digest_history) == 100)
224 session._add_digest(uuid.uuid4().bytes)
212 session._add_digest(uuid.uuid4().bytes)
225 self.assertTrue(len(session.digest_history) == 91)
213 self.assertTrue(len(session.digest_history) == 91)
226 for i in range(9):
214 for i in range(9):
227 session._add_digest(uuid.uuid4().bytes)
215 session._add_digest(uuid.uuid4().bytes)
228 self.assertTrue(len(session.digest_history) == 100)
216 self.assertTrue(len(session.digest_history) == 100)
229 session._add_digest(uuid.uuid4().bytes)
217 session._add_digest(uuid.uuid4().bytes)
230 self.assertTrue(len(session.digest_history) == 91)
218 self.assertTrue(len(session.digest_history) == 91)
231
219
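The 91s above come from the culling policy: once the history exceeds digest_history_size, about 10% of the limit is culled in a single pass, so 100 stored digests plus one new one leaves 91. A sketch of that arithmetic, assuming the 10%-of-limit policy:

    size = 100
    history = size + 1     # the add that triggers a cull
    history -= size // 10  # one pass culls ~10% of the limit
    assert history == 91
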
232 def test_bad_pack(self):
220 def test_bad_pack(self):
233 try:
221 try:
234 session = ss.Session(pack=_bad_packer)
222 session = ss.Session(pack=_bad_packer)
235 except ValueError as e:
223 except ValueError as e:
236 self.assertIn("could not serialize", str(e))
224 self.assertIn("could not serialize", str(e))
237 self.assertIn("don't work", str(e))
225 self.assertIn("don't work", str(e))
238 else:
226 else:
239 self.fail("Should have raised ValueError")
227 self.fail("Should have raised ValueError")
240
228
241 def test_bad_unpack(self):
229 def test_bad_unpack(self):
242 try:
230 try:
243 session = ss.Session(unpack=_bad_unpacker)
231 session = ss.Session(unpack=_bad_unpacker)
244 except ValueError as e:
232 except ValueError as e:
245 self.assertIn("could not handle output", str(e))
233 self.assertIn("could not handle output", str(e))
246 self.assertIn("don't work either", str(e))
234 self.assertIn("don't work either", str(e))
247 else:
235 else:
248 self.fail("Should have raised ValueError")
236 self.fail("Should have raised ValueError")
249
237
250 def test_bad_packer(self):
238 def test_bad_packer(self):
251 try:
239 try:
252 session = ss.Session(packer=__name__ + '._bad_packer')
240 session = ss.Session(packer=__name__ + '._bad_packer')
253 except ValueError as e:
241 except ValueError as e:
254 self.assertIn("could not serialize", str(e))
242 self.assertIn("could not serialize", str(e))
255 self.assertIn("don't work", str(e))
243 self.assertIn("don't work", str(e))
256 else:
244 else:
257 self.fail("Should have raised ValueError")
245 self.fail("Should have raised ValueError")
258
246
259 def test_bad_unpacker(self):
247 def test_bad_unpacker(self):
260 try:
248 try:
261 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
262 except ValueError as e:
250 except ValueError as e:
263 self.assertIn("could not handle output", str(e))
251 self.assertIn("could not handle output", str(e))
264 self.assertIn("don't work either", str(e))
252 self.assertIn("don't work either", str(e))
265 else:
253 else:
266 self.fail("Should have raised ValueError")
254 self.fail("Should have raised ValueError")
267
255
268 def test_bad_roundtrip(self):
256 def test_bad_roundtrip(self):
269 with self.assertRaises(ValueError):
257 with self.assertRaises(ValueError):
270 session= ss.Session(unpack=lambda b: 5)
258 session = ss.Session(unpack=lambda b: 5)
259
260 def _datetime_test(self, session):
261 content = dict(t=datetime.now())
262 metadata = dict(t=datetime.now())
263 p = session.msg('msg')
264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
265 smsg = session.serialize(msg)
266 msg2 = session.unserialize(session.feed_identities(smsg)[1])
267 assert isinstance(msg2['header']['date'], datetime)
268 self.assertEqual(msg['header'], msg2['header'])
269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
270 self.assertEqual(msg['parent_header'], msg2['parent_header'])
271 assert isinstance(msg['content']['t'], datetime)
272 assert isinstance(msg['metadata']['t'], datetime)
273 assert isinstance(msg2['content']['t'], string_types)
274 assert isinstance(msg2['metadata']['t'], string_types)
275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
276 self.assertEqual(msg['metadata'], jsonutil.extract_dates(msg2['metadata']))
277
278 def test_datetimes(self):
279 self._datetime_test(self.session)
280
281 def test_datetimes_pickle(self):
282 session = ss.Session(packer='pickle')
283 self._datetime_test(session)
284
285 @skipif(module_not_available('msgpack'))
286 def test_datetimes_msgpack(self):
287 session = ss.Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
288 self._datetime_test(session)
271
289
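A self-contained sketch of the round trip the datetime tests above exercise, using the standard-library json module with jsonutil's date_default hook (the payload is illustrative):

    import json
    from datetime import datetime
    from IPython.utils.jsonutil import date_default, extract_dates

    sent = {'t': datetime.now()}
    wire = json.dumps(sent, default=date_default)  # datetime -> ISO8601 string
    received = json.loads(wire)                    # plain string off the wire
    assert isinstance(extract_dates(received)['t'], datetime)
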
@@ -1,1855 +1,1862
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 from __future__ import print_function
7 from __future__ import print_function
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import os
19 import os
20 import json
20 import json
21 import sys
21 import sys
22 from threading import Thread, Event
22 from threading import Thread, Event
23 import time
23 import time
24 import warnings
24 import warnings
25 from datetime import datetime
25 from datetime import datetime
26 from getpass import getpass
26 from getpass import getpass
27 from pprint import pprint
27 from pprint import pprint
28
28
29 pjoin = os.path.join
29 pjoin = os.path.join
30
30
31 import zmq
31 import zmq
32 # from zmq.eventloop import ioloop, zmqstream
32 # from zmq.eventloop import ioloop, zmqstream
33
33
34 from IPython.config.configurable import MultipleInstanceError
34 from IPython.config.configurable import MultipleInstanceError
35 from IPython.core.application import BaseIPythonApplication
35 from IPython.core.application import BaseIPythonApplication
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37
37
38 from IPython.utils.capture import RichOutput
38 from IPython.utils.capture import RichOutput
39 from IPython.utils.coloransi import TermColors
39 from IPython.utils.coloransi import TermColors
40 from IPython.utils.jsonutil import rekey
40 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
41 from IPython.utils.localinterfaces import localhost, is_local_ip
41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 from IPython.utils.path import get_ipython_dir
42 from IPython.utils.path import get_ipython_dir
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 Dict, List, Bool, Set, Any)
45 Dict, List, Bool, Set, Any)
46 from IPython.external.decorator import decorator
46 from IPython.external.decorator import decorator
47 from IPython.external.ssh import tunnel
47 from IPython.external.ssh import tunnel
48
48
49 from IPython.parallel import Reference
49 from IPython.parallel import Reference
50 from IPython.parallel import error
50 from IPython.parallel import error
51 from IPython.parallel import util
51 from IPython.parallel import util
52
52
53 from IPython.kernel.zmq.session import Session, Message
53 from IPython.kernel.zmq.session import Session, Message
54 from IPython.kernel.zmq import serialize
54 from IPython.kernel.zmq import serialize
55
55
56 from .asyncresult import AsyncResult, AsyncHubResult
56 from .asyncresult import AsyncResult, AsyncHubResult
57 from .view import DirectView, LoadBalancedView
57 from .view import DirectView, LoadBalancedView
58
58
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60 # Decorators for Client methods
60 # Decorators for Client methods
61 #--------------------------------------------------------------------------
61 #--------------------------------------------------------------------------
62
62
63 @decorator
63 @decorator
64 def spin_first(f, self, *args, **kwargs):
64 def spin_first(f, self, *args, **kwargs):
65 """Call spin() to sync state prior to calling the method."""
65 """Call spin() to sync state prior to calling the method."""
66 self.spin()
66 self.spin()
67 return f(self, *args, **kwargs)
67 return f(self, *args, **kwargs)
68
68
69
69
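A sketch of how spin_first is meant to be applied; the class and method here are hypothetical stand-ins, since the real decorated methods appear further down in Client:

    class _Sketch(object):
        def spin(self):
            print("flush incoming results first")

        @spin_first
        def queue_status(self):  # hypothetical method for illustration
            return "up to date"

    _Sketch().queue_status()  # spins first, then runs the method
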
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71 # Classes
71 # Classes
72 #--------------------------------------------------------------------------
72 #--------------------------------------------------------------------------
73
73
74
74
75 class ExecuteReply(RichOutput):
75 class ExecuteReply(RichOutput):
76 """wrapper for finished Execute results"""
76 """wrapper for finished Execute results"""
77 def __init__(self, msg_id, content, metadata):
77 def __init__(self, msg_id, content, metadata):
78 self.msg_id = msg_id
78 self.msg_id = msg_id
79 self._content = content
79 self._content = content
80 self.execution_count = content['execution_count']
80 self.execution_count = content['execution_count']
81 self.metadata = metadata
81 self.metadata = metadata
82
82
83 # RichOutput overrides
83 # RichOutput overrides
84
84
85 @property
85 @property
86 def source(self):
86 def source(self):
87 pyout = self.metadata['pyout']
87 pyout = self.metadata['pyout']
88 if pyout:
88 if pyout:
89 return pyout.get('source', '')
89 return pyout.get('source', '')
90
90
91 @property
91 @property
92 def data(self):
92 def data(self):
93 pyout = self.metadata['pyout']
93 pyout = self.metadata['pyout']
94 if pyout:
94 if pyout:
95 return pyout.get('data', {})
95 return pyout.get('data', {})
96
96
97 @property
97 @property
98 def _metadata(self):
98 def _metadata(self):
99 pyout = self.metadata['pyout']
99 pyout = self.metadata['pyout']
100 if pyout:
100 if pyout:
101 return pyout.get('metadata', {})
101 return pyout.get('metadata', {})
102
102
103 def display(self):
103 def display(self):
104 from IPython.display import publish_display_data
104 from IPython.display import publish_display_data
105 publish_display_data(self.source, self.data, self.metadata)
105 publish_display_data(self.source, self.data, self.metadata)
106
106
107 def _repr_mime_(self, mime):
107 def _repr_mime_(self, mime):
108 if mime not in self.data:
108 if mime not in self.data:
109 return
109 return
110 data = self.data[mime]
110 data = self.data[mime]
111 if mime in self._metadata:
111 if mime in self._metadata:
112 return data, self._metadata[mime]
112 return data, self._metadata[mime]
113 else:
113 else:
114 return data
114 return data
115
115
116 def __getitem__(self, key):
116 def __getitem__(self, key):
117 return self.metadata[key]
117 return self.metadata[key]
118
118
119 def __getattr__(self, key):
119 def __getattr__(self, key):
120 if key not in self.metadata:
120 if key not in self.metadata:
121 raise AttributeError(key)
121 raise AttributeError(key)
122 return self.metadata[key]
122 return self.metadata[key]
123
123
124 def __repr__(self):
124 def __repr__(self):
125 pyout = self.metadata['pyout'] or {'data':{}}
125 pyout = self.metadata['pyout'] or {'data':{}}
126 text_out = pyout['data'].get('text/plain', '')
126 text_out = pyout['data'].get('text/plain', '')
127 if len(text_out) > 32:
127 if len(text_out) > 32:
128 text_out = text_out[:29] + '...'
128 text_out = text_out[:29] + '...'
129
129
130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
131
131
132 def _repr_pretty_(self, p, cycle):
132 def _repr_pretty_(self, p, cycle):
133 pyout = self.metadata['pyout'] or {'data':{}}
133 pyout = self.metadata['pyout'] or {'data':{}}
134 text_out = pyout['data'].get('text/plain', '')
134 text_out = pyout['data'].get('text/plain', '')
135
135
136 if not text_out:
136 if not text_out:
137 return
137 return
138
138
139 try:
139 try:
140 ip = get_ipython()
140 ip = get_ipython()
141 except NameError:
141 except NameError:
142 colors = "NoColor"
142 colors = "NoColor"
143 else:
143 else:
144 colors = ip.colors
144 colors = ip.colors
145
145
146 if colors == "NoColor":
146 if colors == "NoColor":
147 out = normal = ""
147 out = normal = ""
148 else:
148 else:
149 out = TermColors.Red
149 out = TermColors.Red
150 normal = TermColors.Normal
150 normal = TermColors.Normal
151
151
152 if '\n' in text_out and not text_out.startswith('\n'):
152 if '\n' in text_out and not text_out.startswith('\n'):
153 # add newline for multiline reprs
153 # add newline for multiline reprs
154 text_out = '\n' + text_out
154 text_out = '\n' + text_out
155
155
156 p.text(
156 p.text(
157 out + u'Out[%i:%i]: ' % (
157 out + u'Out[%i:%i]: ' % (
158 self.metadata['engine_id'], self.execution_count
158 self.metadata['engine_id'], self.execution_count
159 ) + normal + text_out
159 ) + normal + text_out
160 )
160 )
161
161
162
162
163 class Metadata(dict):
163 class Metadata(dict):
164 """Subclass of dict for initializing metadata values.
164 """Subclass of dict for initializing metadata values.
165
165
166 Attribute access works on keys.
166 Attribute access works on keys.
167
167
168 These objects have a strict set of keys - errors will raise if you try
168 These objects have a strict set of keys - errors will raise if you try
169 to add new keys.
169 to add new keys.
170 """
170 """
171 def __init__(self, *args, **kwargs):
171 def __init__(self, *args, **kwargs):
172 dict.__init__(self)
172 dict.__init__(self)
173 md = {'msg_id' : None,
173 md = {'msg_id' : None,
174 'submitted' : None,
174 'submitted' : None,
175 'started' : None,
175 'started' : None,
176 'completed' : None,
176 'completed' : None,
177 'received' : None,
177 'received' : None,
178 'engine_uuid' : None,
178 'engine_uuid' : None,
179 'engine_id' : None,
179 'engine_id' : None,
180 'follow' : None,
180 'follow' : None,
181 'after' : None,
181 'after' : None,
182 'status' : None,
182 'status' : None,
183
183
184 'pyin' : None,
184 'pyin' : None,
185 'pyout' : None,
185 'pyout' : None,
186 'pyerr' : None,
186 'pyerr' : None,
187 'stdout' : '',
187 'stdout' : '',
188 'stderr' : '',
188 'stderr' : '',
189 'outputs' : [],
189 'outputs' : [],
190 'data': {},
190 'data': {},
191 'outputs_ready' : False,
191 'outputs_ready' : False,
192 }
192 }
193 self.update(md)
193 self.update(md)
194 self.update(dict(*args, **kwargs))
194 self.update(dict(*args, **kwargs))
195
195
196 def __getattr__(self, key):
196 def __getattr__(self, key):
197 """getattr aliased to getitem"""
197 """getattr aliased to getitem"""
198 if key in self:
198 if key in self:
199 return self[key]
199 return self[key]
200 else:
200 else:
201 raise AttributeError(key)
201 raise AttributeError(key)
202
202
203 def __setattr__(self, key, value):
203 def __setattr__(self, key, value):
204 """setattr aliased to setitem, with strict"""
204 """setattr aliased to setitem, with strict"""
205 if key in self:
205 if key in self:
206 self[key] = value
206 self[key] = value
207 else:
207 else:
208 raise AttributeError(key)
208 raise AttributeError(key)
209
209
210 def __setitem__(self, key, value):
210 def __setitem__(self, key, value):
211 """strict static key enforcement"""
211 """strict static key enforcement"""
212 if key in self:
212 if key in self:
213 dict.__setitem__(self, key, value)
213 dict.__setitem__(self, key, value)
214 else:
214 else:
215 raise KeyError(key)
215 raise KeyError(key)
216
216
217
217
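A short sketch of the strict-key behavior the Metadata docstring describes:

    md = Metadata(status='ok')
    md.stdout = 'hi'  # attribute access works for the fixed key set
    assert md['status'] == 'ok'
    try:
        md['no_such_key'] = 1
    except KeyError:
        pass  # unknown keys are rejected, as documented
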
218 class Client(HasTraits):
218 class Client(HasTraits):
219 """A semi-synchronous client to the IPython ZMQ cluster
219 """A semi-synchronous client to the IPython ZMQ cluster
220
220
221 Parameters
221 Parameters
222 ----------
222 ----------
223
223
224 url_file : str/unicode; path to ipcontroller-client.json
224 url_file : str/unicode; path to ipcontroller-client.json
225 This JSON file should contain all the information needed to connect to a cluster,
225 This JSON file should contain all the information needed to connect to a cluster,
226 and is likely the only argument needed: it supplies the connection
226 and is likely the only argument needed: it supplies the connection
227 information for the Hub's registration. If a json connector file is
227 information for the Hub's registration. If a json connector file is
228 given, usually no further configuration is necessary.
228 given, usually no further configuration is necessary.
229 [Default: use profile]
229 [Default: use profile]
230 profile : bytes
230 profile : bytes
231 The name of the Cluster profile to be used to find connector information.
231 The name of the Cluster profile to be used to find connector information.
232 If run from an IPython application, the default profile will be the same
232 If run from an IPython application, the default profile will be the same
233 as the running application, otherwise it will be 'default'.
233 as the running application, otherwise it will be 'default'.
234 cluster_id : str
234 cluster_id : str
235 String id to be added to runtime files, to prevent name collisions when using
235 String id to be added to runtime files, to prevent name collisions when using
236 multiple clusters with a single profile simultaneously.
236 multiple clusters with a single profile simultaneously.
237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
238 Since this is text inserted into filenames, typical recommendations apply:
238 Since this is text inserted into filenames, typical recommendations apply:
239 Simple character strings are ideal, and spaces are not recommended (but
239 Simple character strings are ideal, and spaces are not recommended (but
240 should generally work).
240 should generally work).
241 context : zmq.Context
241 context : zmq.Context
242 Pass an existing zmq.Context instance, otherwise the client will create its own.
242 Pass an existing zmq.Context instance, otherwise the client will create its own.
243 debug : bool
243 debug : bool
244 flag for lots of message printing for debug purposes
244 flag for lots of message printing for debug purposes
245 timeout : int/float
245 timeout : int/float
246 time (in seconds) to wait for connection replies from the Hub
246 time (in seconds) to wait for connection replies from the Hub
247 [Default: 10]
247 [Default: 10]
248
248
249 #-------------- session related args ----------------
249 #-------------- session related args ----------------
250
250
251 config : Config object
251 config : Config object
252 If specified, this will be relayed to the Session for configuration
252 If specified, this will be relayed to the Session for configuration
253 username : str
253 username : str
254 set username for the session object
254 set username for the session object
255
255
256 #-------------- ssh related args ----------------
256 #-------------- ssh related args ----------------
257 # These are args for configuring the ssh tunnel to be used
257 # These are args for configuring the ssh tunnel to be used
258 # credentials are used to forward connections over ssh to the Controller
258 # credentials are used to forward connections over ssh to the Controller
259 # Note that the ip given in `addr` needs to be relative to sshserver
259 # Note that the ip given in `addr` needs to be relative to sshserver
260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
261 # and set sshserver as the same machine the Controller is on. However,
261 # and set sshserver as the same machine the Controller is on. However,
262 # the only requirement is that sshserver is able to see the Controller
262 # the only requirement is that sshserver is able to see the Controller
263 # (i.e. is within the same trusted network).
263 # (i.e. is within the same trusted network).
264
264
265 sshserver : str
265 sshserver : str
266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
267 If keyfile or password is specified, and this is not, it will default to
267 If keyfile or password is specified, and this is not, it will default to
268 the ip given in addr.
268 the ip given in addr.
269 sshkey : str; path to ssh private key file
269 sshkey : str; path to ssh private key file
270 This specifies a key to be used in ssh login, default None.
270 This specifies a key to be used in ssh login, default None.
271 Regular default ssh keys will be used without specifying this argument.
271 Regular default ssh keys will be used without specifying this argument.
272 password : str
272 password : str
273 Your ssh password to sshserver. Note that if this is left None,
273 Your ssh password to sshserver. Note that if this is left None,
274 you will be prompted for it if passwordless key-based login is unavailable.
274 you will be prompted for it if passwordless key-based login is unavailable.
275 paramiko : bool
275 paramiko : bool
276 flag for whether to use paramiko instead of shell ssh for tunneling.
276 flag for whether to use paramiko instead of shell ssh for tunneling.
277 [default: True on win32, False else]
277 [default: True on win32, False else]
278
278
279
279
280 Attributes
280 Attributes
281 ----------
281 ----------
282
282
283 ids : list of int engine IDs
283 ids : list of int engine IDs
284 requesting the ids attribute always synchronizes
284 requesting the ids attribute always synchronizes
285 the registration state. To request ids without synchronization,
285 the registration state. To request ids without synchronization,
286 use the semi-private _ids attribute.
286 use the semi-private _ids attribute.
287
287
288 history : list of msg_ids
288 history : list of msg_ids
289 a list of msg_ids, keeping track of all the execution
289 a list of msg_ids, keeping track of all the execution
290 messages you have submitted in order.
290 messages you have submitted in order.
291
291
292 outstanding : set of msg_ids
292 outstanding : set of msg_ids
293 a set of msg_ids that have been submitted, but whose
293 a set of msg_ids that have been submitted, but whose
294 results have not yet been received.
294 results have not yet been received.
295
295
296 results : dict
296 results : dict
297 a dict of all our results, keyed by msg_id
297 a dict of all our results, keyed by msg_id
298
298
299 block : bool
299 block : bool
300 determines default behavior when block not specified
300 determines default behavior when block not specified
301 in execution methods
301 in execution methods
302
302
303 Methods
303 Methods
304 -------
304 -------
305
305
306 spin
306 spin
307 flushes incoming results and registration state changes
307 flushes incoming results and registration state changes
308 control methods spin, and requesting `ids` also ensures up to date
308 control methods spin, and requesting `ids` also ensures up to date
309
309
310 wait
310 wait
311 wait on one or more msg_ids
311 wait on one or more msg_ids
312
312
313 execution methods
313 execution methods
314 apply
314 apply
315 legacy: execute, run
315 legacy: execute, run
316
316
317 data movement
317 data movement
318 push, pull, scatter, gather
318 push, pull, scatter, gather
319
319
320 query methods
320 query methods
321 queue_status, get_result, purge, result_status
321 queue_status, get_result, purge, result_status
322
322
323 control methods
323 control methods
324 abort, shutdown
324 abort, shutdown
325
325
326 """
326 """
327
327
328
328
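A hedged sketch of typical construction per the Parameters above; the path and host are placeholders, not values from this changeset:

    # from IPython.parallel import Client
    # rc = Client()                                      # default profile's url_file
    # rc = Client('/path/to/ipcontroller-client.json')   # explicit connector file
    # rc = Client(profile='mpi', sshserver='user@gateway.example.com')
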
329 block = Bool(False)
329 block = Bool(False)
330 outstanding = Set()
330 outstanding = Set()
331 results = Instance('collections.defaultdict', (dict,))
331 results = Instance('collections.defaultdict', (dict,))
332 metadata = Instance('collections.defaultdict', (Metadata,))
332 metadata = Instance('collections.defaultdict', (Metadata,))
333 history = List()
333 history = List()
334 debug = Bool(False)
334 debug = Bool(False)
335 _spin_thread = Any()
335 _spin_thread = Any()
336 _stop_spinning = Any()
336 _stop_spinning = Any()
337
337
338 profile=Unicode()
338 profile=Unicode()
339 def _profile_default(self):
339 def _profile_default(self):
340 if BaseIPythonApplication.initialized():
340 if BaseIPythonApplication.initialized():
341 # an IPython app *might* be running, try to get its profile
341 # an IPython app *might* be running, try to get its profile
342 try:
342 try:
343 return BaseIPythonApplication.instance().profile
343 return BaseIPythonApplication.instance().profile
344 except (AttributeError, MultipleInstanceError):
344 except (AttributeError, MultipleInstanceError):
345 # could be a *different* subclass of config.Application,
345 # could be a *different* subclass of config.Application,
346 # which would raise one of these two errors.
346 # which would raise one of these two errors.
347 return u'default'
347 return u'default'
348 else:
348 else:
349 return u'default'
349 return u'default'
350
350
351
351
352 _outstanding_dict = Instance('collections.defaultdict', (set,))
352 _outstanding_dict = Instance('collections.defaultdict', (set,))
353 _ids = List()
353 _ids = List()
354 _connected=Bool(False)
354 _connected=Bool(False)
355 _ssh=Bool(False)
355 _ssh=Bool(False)
356 _context = Instance('zmq.Context')
356 _context = Instance('zmq.Context')
357 _config = Dict()
357 _config = Dict()
358 _engines=Instance(util.ReverseDict, (), {})
358 _engines=Instance(util.ReverseDict, (), {})
359 # _hub_socket=Instance('zmq.Socket')
359 # _hub_socket=Instance('zmq.Socket')
360 _query_socket=Instance('zmq.Socket')
360 _query_socket=Instance('zmq.Socket')
361 _control_socket=Instance('zmq.Socket')
361 _control_socket=Instance('zmq.Socket')
362 _iopub_socket=Instance('zmq.Socket')
362 _iopub_socket=Instance('zmq.Socket')
363 _notification_socket=Instance('zmq.Socket')
363 _notification_socket=Instance('zmq.Socket')
364 _mux_socket=Instance('zmq.Socket')
364 _mux_socket=Instance('zmq.Socket')
365 _task_socket=Instance('zmq.Socket')
365 _task_socket=Instance('zmq.Socket')
366 _task_scheme=Unicode()
366 _task_scheme=Unicode()
367 _closed = False
367 _closed = False
368 _ignored_control_replies=Integer(0)
368 _ignored_control_replies=Integer(0)
369 _ignored_hub_replies=Integer(0)
369 _ignored_hub_replies=Integer(0)
370
370
371 def __new__(self, *args, **kw):
371 def __new__(self, *args, **kw):
372 # don't raise on positional args
372 # don't raise on positional args
373 return HasTraits.__new__(self, **kw)
373 return HasTraits.__new__(self, **kw)
374
374
375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
376 context=None, debug=False,
376 context=None, debug=False,
377 sshserver=None, sshkey=None, password=None, paramiko=None,
377 sshserver=None, sshkey=None, password=None, paramiko=None,
378 timeout=10, cluster_id=None, **extra_args
378 timeout=10, cluster_id=None, **extra_args
379 ):
379 ):
380 if profile:
380 if profile:
381 super(Client, self).__init__(debug=debug, profile=profile)
381 super(Client, self).__init__(debug=debug, profile=profile)
382 else:
382 else:
383 super(Client, self).__init__(debug=debug)
383 super(Client, self).__init__(debug=debug)
384 if context is None:
384 if context is None:
385 context = zmq.Context.instance()
385 context = zmq.Context.instance()
386 self._context = context
386 self._context = context
387 self._stop_spinning = Event()
387 self._stop_spinning = Event()
388
388
389 if 'url_or_file' in extra_args:
389 if 'url_or_file' in extra_args:
390 url_file = extra_args['url_or_file']
390 url_file = extra_args['url_or_file']
391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
392
392
393 if url_file and util.is_url(url_file):
393 if url_file and util.is_url(url_file):
394 raise ValueError("single urls cannot be specified, url-files must be used.")
394 raise ValueError("single urls cannot be specified, url-files must be used.")
395
395
396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
397
397
398 if self._cd is not None:
398 if self._cd is not None:
399 if url_file is None:
399 if url_file is None:
400 if not cluster_id:
400 if not cluster_id:
401 client_json = 'ipcontroller-client.json'
401 client_json = 'ipcontroller-client.json'
402 else:
402 else:
403 client_json = 'ipcontroller-%s-client.json' % cluster_id
403 client_json = 'ipcontroller-%s-client.json' % cluster_id
404 url_file = pjoin(self._cd.security_dir, client_json)
404 url_file = pjoin(self._cd.security_dir, client_json)
405 if url_file is None:
405 if url_file is None:
406 raise ValueError(
406 raise ValueError(
407 "I can't find enough information to connect to a hub!"
407 "I can't find enough information to connect to a hub!"
408 " Please specify at least one of url_file or profile."
408 " Please specify at least one of url_file or profile."
409 )
409 )
410
410
411 with open(url_file) as f:
411 with open(url_file) as f:
412 cfg = json.load(f)
412 cfg = json.load(f)
413
413
414 self._task_scheme = cfg['task_scheme']
414 self._task_scheme = cfg['task_scheme']
415
415
416 # sync defaults from args, json:
416 # sync defaults from args, json:
417 if sshserver:
417 if sshserver:
418 cfg['ssh'] = sshserver
418 cfg['ssh'] = sshserver
419
419
420 location = cfg.setdefault('location', None)
420 location = cfg.setdefault('location', None)
421
421
422 proto,addr = cfg['interface'].split('://')
422 proto,addr = cfg['interface'].split('://')
423 addr = util.disambiguate_ip_address(addr, location)
423 addr = util.disambiguate_ip_address(addr, location)
424 cfg['interface'] = "%s://%s" % (proto, addr)
424 cfg['interface'] = "%s://%s" % (proto, addr)
425
425
426 # turn interface,port into full urls:
426 # turn interface,port into full urls:
427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
429
429
430 url = cfg['registration']
430 url = cfg['registration']
431
431
432 if location is not None and addr == localhost():
432 if location is not None and addr == localhost():
433 # location specified, and connection is expected to be local
433 # location specified, and connection is expected to be local
434 if not is_local_ip(location) and not sshserver:
434 if not is_local_ip(location) and not sshserver:
435 # load ssh from JSON *only* if the controller is not on
435 # load ssh from JSON *only* if the controller is not on
436 # this machine
436 # this machine
437 sshserver=cfg['ssh']
437 sshserver=cfg['ssh']
438 if not is_local_ip(location) and not sshserver:
438 if not is_local_ip(location) and not sshserver:
439 # warn if no ssh specified, but SSH is probably needed
439 # warn if no ssh specified, but SSH is probably needed
440 # This is only a warning, because the most likely cause
440 # This is only a warning, because the most likely cause
441 # is a local Controller on a laptop whose IP is dynamic
441 # is a local Controller on a laptop whose IP is dynamic
442 warnings.warn("""
442 warnings.warn("""
443 Controller appears to be listening on localhost, but not on this machine.
443 Controller appears to be listening on localhost, but not on this machine.
444 If this is true, you should specify Client(...,sshserver='you@%s')
444 If this is true, you should specify Client(...,sshserver='you@%s')
445 or instruct your controller to listen on an external IP."""%location,
445 or instruct your controller to listen on an external IP."""%location,
446 RuntimeWarning)
446 RuntimeWarning)
447 elif not sshserver:
447 elif not sshserver:
448 # otherwise sync with cfg
448 # otherwise sync with cfg
449 sshserver = cfg['ssh']
449 sshserver = cfg['ssh']
450
450
451 self._config = cfg
451 self._config = cfg
452
452
453 self._ssh = bool(sshserver or sshkey or password)
453 self._ssh = bool(sshserver or sshkey or password)
454 if self._ssh and sshserver is None:
454 if self._ssh and sshserver is None:
455 # default to ssh via localhost
455 # default to ssh via localhost
456 sshserver = addr
456 sshserver = addr
457 if self._ssh and password is None:
457 if self._ssh and password is None:
458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
459 password=False
459 password=False
460 else:
460 else:
461 password = getpass("SSH Password for %s: "%sshserver)
461 password = getpass("SSH Password for %s: "%sshserver)
462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
463
463
464 # configure and construct the session
464 # configure and construct the session
465 try:
465 try:
466 extra_args['packer'] = cfg['pack']
466 extra_args['packer'] = cfg['pack']
467 extra_args['unpacker'] = cfg['unpack']
467 extra_args['unpacker'] = cfg['unpack']
468 extra_args['key'] = cast_bytes(cfg['key'])
468 extra_args['key'] = cast_bytes(cfg['key'])
469 extra_args['signature_scheme'] = cfg['signature_scheme']
469 extra_args['signature_scheme'] = cfg['signature_scheme']
470 except KeyError as exc:
470 except KeyError as exc:
471 msg = '\n'.join([
471 msg = '\n'.join([
472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
473 "If you are reusing connection files, remove them and start ipcontroller again."
473 "If you are reusing connection files, remove them and start ipcontroller again."
474 ])
474 ])
475 raise ValueError(msg.format(exc.args[0]))
475 raise ValueError(msg.format(exc.args[0]))
476
476
477 self.session = Session(**extra_args)
477 self.session = Session(**extra_args)
478
478
479 self._query_socket = self._context.socket(zmq.DEALER)
479 self._query_socket = self._context.socket(zmq.DEALER)
480
480
481 if self._ssh:
481 if self._ssh:
482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
483 else:
483 else:
484 self._query_socket.connect(cfg['registration'])
484 self._query_socket.connect(cfg['registration'])
485
485
486 self.session.debug = self.debug
486 self.session.debug = self.debug
487
487
488 self._notification_handlers = {'registration_notification' : self._register_engine,
488 self._notification_handlers = {'registration_notification' : self._register_engine,
489 'unregistration_notification' : self._unregister_engine,
489 'unregistration_notification' : self._unregister_engine,
490 'shutdown_notification' : lambda msg: self.close(),
490 'shutdown_notification' : lambda msg: self.close(),
491 }
491 }
492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
493 'apply_reply' : self._handle_apply_reply}
493 'apply_reply' : self._handle_apply_reply}
494
494
495 try:
495 try:
496 self._connect(sshserver, ssh_kwargs, timeout)
496 self._connect(sshserver, ssh_kwargs, timeout)
497 except:
497 except:
498 self.close(linger=0)
498 self.close(linger=0)
499 raise
499 raise
500
500
501 # last step: setup magics, if we are in IPython:
501 # last step: setup magics, if we are in IPython:
502
502
503 try:
503 try:
504 ip = get_ipython()
504 ip = get_ipython()
505 except NameError:
505 except NameError:
506 return
506 return
507 else:
507 else:
508 if 'px' not in ip.magics_manager.magics:
508 if 'px' not in ip.magics_manager.magics:
509 # in IPython but we are the first Client.
509 # in IPython but we are the first Client.
510 # activate a default view for parallel magics.
510 # activate a default view for parallel magics.
511 self.activate()
511 self.activate()
512
512
513 def __del__(self):
513 def __del__(self):
514 """cleanup sockets, but _not_ context."""
514 """cleanup sockets, but _not_ context."""
515 self.close()
515 self.close()
516
516
517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
518 if ipython_dir is None:
518 if ipython_dir is None:
519 ipython_dir = get_ipython_dir()
519 ipython_dir = get_ipython_dir()
520 if profile_dir is not None:
520 if profile_dir is not None:
521 try:
521 try:
522 self._cd = ProfileDir.find_profile_dir(profile_dir)
522 self._cd = ProfileDir.find_profile_dir(profile_dir)
523 return
523 return
524 except ProfileDirError:
524 except ProfileDirError:
525 pass
525 pass
526 elif profile is not None:
526 elif profile is not None:
527 try:
527 try:
528 self._cd = ProfileDir.find_profile_dir_by_name(
528 self._cd = ProfileDir.find_profile_dir_by_name(
529 ipython_dir, profile)
529 ipython_dir, profile)
530 return
530 return
531 except ProfileDirError:
531 except ProfileDirError:
532 pass
532 pass
533 self._cd = None
533 self._cd = None
534
534
535 def _update_engines(self, engines):
535 def _update_engines(self, engines):
536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
537 for k,v in iteritems(engines):
537 for k,v in iteritems(engines):
538 eid = int(k)
538 eid = int(k)
539 if eid not in self._engines:
539 if eid not in self._engines:
540 self._ids.append(eid)
540 self._ids.append(eid)
541 self._engines[eid] = v
541 self._engines[eid] = v
542 self._ids = sorted(self._ids)
542 self._ids = sorted(self._ids)
543 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
543 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
544 self._task_scheme == 'pure' and self._task_socket:
544 self._task_scheme == 'pure' and self._task_socket:
545 self._stop_scheduling_tasks()
545 self._stop_scheduling_tasks()
546
546
547 def _stop_scheduling_tasks(self):
547 def _stop_scheduling_tasks(self):
548 """Stop scheduling tasks because an engine has been unregistered
548 """Stop scheduling tasks because an engine has been unregistered
549 from a pure ZMQ scheduler.
549 from a pure ZMQ scheduler.
550 """
550 """
551 self._task_socket.close()
551 self._task_socket.close()
552 self._task_socket = None
552 self._task_socket = None
553 msg = "An engine has been unregistered, and we are using pure " +\
553 msg = "An engine has been unregistered, and we are using pure " +\
554 "ZMQ task scheduling. Task farming will be disabled."
554 "ZMQ task scheduling. Task farming will be disabled."
555 if self.outstanding:
555 if self.outstanding:
556 msg += " If you were running tasks when this happened, " +\
556 msg += " If you were running tasks when this happened, " +\
557 "some `outstanding` msg_ids may never resolve."
557 "some `outstanding` msg_ids may never resolve."
558 warnings.warn(msg, RuntimeWarning)
558 warnings.warn(msg, RuntimeWarning)
559
559
560 def _build_targets(self, targets):
560 def _build_targets(self, targets):
561 """Turn valid target IDs or 'all' into two lists:
561 """Turn valid target IDs or 'all' into two lists:
562 (int_ids, uuids).
562 (int_ids, uuids).
563 """
563 """
564 if not self._ids:
564 if not self._ids:
565 # flush notification socket if no engines yet, just in case
565 # flush notification socket if no engines yet, just in case
566 if not self.ids:
566 if not self.ids:
567 raise error.NoEnginesRegistered("Can't build targets without any engines")
567 raise error.NoEnginesRegistered("Can't build targets without any engines")
568
568
569 if targets is None:
569 if targets is None:
570 targets = self._ids
570 targets = self._ids
571 elif isinstance(targets, string_types):
571 elif isinstance(targets, string_types):
572 if targets.lower() == 'all':
572 if targets.lower() == 'all':
573 targets = self._ids
573 targets = self._ids
574 else:
574 else:
575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
576 elif isinstance(targets, int):
576 elif isinstance(targets, int):
577 if targets < 0:
577 if targets < 0:
578 targets = self.ids[targets]
578 targets = self.ids[targets]
579 if targets not in self._ids:
579 if targets not in self._ids:
580 raise IndexError("No such engine: %i"%targets)
580 raise IndexError("No such engine: %i"%targets)
581 targets = [targets]
581 targets = [targets]
582
582
583 if isinstance(targets, slice):
583 if isinstance(targets, slice):
584 indices = list(range(len(self._ids))[targets])
584 indices = list(range(len(self._ids))[targets])
585 ids = self.ids
585 ids = self.ids
586 targets = [ ids[i] for i in indices ]
586 targets = [ ids[i] for i in indices ]
587
587
588 if not isinstance(targets, (tuple, list, xrange)):
588 if not isinstance(targets, (tuple, list, xrange)):
589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
590
590
591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
592
592
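# Illustrative sketch (not part of the changeset): how _build_targets
# normalizes the accepted target specs into (uuids, int_ids), assuming a
# client `rc` with engines 0-3 registered:
#
#     rc._build_targets(None)          # all registered ids
#     rc._build_targets('all')         # same as None
#     rc._build_targets(-1)            # negative ints index self.ids
#     rc._build_targets(slice(0, 2))   # slices expand to [0, 1]
#     rc._build_targets([0, 2])        # lists/tuples pass through
#
# Anything else raises TypeError; an int that is not a registered engine
# id raises IndexError.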
593 def _connect(self, sshserver, ssh_kwargs, timeout):
593 def _connect(self, sshserver, ssh_kwargs, timeout):
594 """setup all our socket connections to the cluster. This is called from
594 """setup all our socket connections to the cluster. This is called from
595 __init__."""
595 __init__."""
596
596
597 # Maybe allow reconnecting?
597 # Maybe allow reconnecting?
598 if self._connected:
598 if self._connected:
599 return
599 return
600 self._connected=True
600 self._connected=True
601
601
602 def connect_socket(s, url):
602 def connect_socket(s, url):
603 if self._ssh:
603 if self._ssh:
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
605 else:
605 else:
606 return s.connect(url)
606 return s.connect(url)
607
607
608 self.session.send(self._query_socket, 'connection_request')
608 self.session.send(self._query_socket, 'connection_request')
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
610 poller = zmq.Poller()
610 poller = zmq.Poller()
611 poller.register(self._query_socket, zmq.POLLIN)
611 poller.register(self._query_socket, zmq.POLLIN)
612 # poll expects milliseconds, timeout is seconds
612 # poll expects milliseconds, timeout is seconds
613 evts = poller.poll(timeout*1000)
613 evts = poller.poll(timeout*1000)
614 if not evts:
614 if not evts:
615 raise error.TimeoutError("Hub connection request timed out")
615 raise error.TimeoutError("Hub connection request timed out")
616 idents,msg = self.session.recv(self._query_socket,mode=0)
616 idents,msg = self.session.recv(self._query_socket,mode=0)
617 if self.debug:
617 if self.debug:
618 pprint(msg)
618 pprint(msg)
619 content = msg['content']
619 content = msg['content']
620 # self._config['registration'] = dict(content)
620 # self._config['registration'] = dict(content)
621 cfg = self._config
621 cfg = self._config
622 if content['status'] == 'ok':
622 if content['status'] == 'ok':
623 self._mux_socket = self._context.socket(zmq.DEALER)
623 self._mux_socket = self._context.socket(zmq.DEALER)
624 connect_socket(self._mux_socket, cfg['mux'])
624 connect_socket(self._mux_socket, cfg['mux'])
625
625
626 self._task_socket = self._context.socket(zmq.DEALER)
626 self._task_socket = self._context.socket(zmq.DEALER)
627 connect_socket(self._task_socket, cfg['task'])
627 connect_socket(self._task_socket, cfg['task'])
628
628
629 self._notification_socket = self._context.socket(zmq.SUB)
629 self._notification_socket = self._context.socket(zmq.SUB)
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
631 connect_socket(self._notification_socket, cfg['notification'])
631 connect_socket(self._notification_socket, cfg['notification'])
632
632
633 self._control_socket = self._context.socket(zmq.DEALER)
633 self._control_socket = self._context.socket(zmq.DEALER)
634 connect_socket(self._control_socket, cfg['control'])
634 connect_socket(self._control_socket, cfg['control'])
635
635
636 self._iopub_socket = self._context.socket(zmq.SUB)
636 self._iopub_socket = self._context.socket(zmq.SUB)
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
638 connect_socket(self._iopub_socket, cfg['iopub'])
638 connect_socket(self._iopub_socket, cfg['iopub'])
639
639
640 self._update_engines(dict(content['engines']))
640 self._update_engines(dict(content['engines']))
641 else:
641 else:
642 self._connected = False
642 self._connected = False
643 raise Exception("Failed to connect!")
643 raise Exception("Failed to connect!")
644
644
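# Illustrative outline (not part of the changeset) of the handshake
# implemented above, using the names from _connect:
#
#     1. send 'connection_request' on the query socket
#     2. poll for the reply, converting `timeout` from seconds to ms
#     3. on status 'ok', connect DEALER sockets (mux, task, control) and
#        SUB sockets (notification, iopub) to the URLs in cfg, tunneling
#        each connection over SSH when self._ssh is set
#     4. seed the engine table via _update_engines(content['engines'])
#
# Any non-'ok' reply clears self._connected and raises.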
645 #--------------------------------------------------------------------------
645 #--------------------------------------------------------------------------
646 # handlers and callbacks for incoming messages
646 # handlers and callbacks for incoming messages
647 #--------------------------------------------------------------------------
647 #--------------------------------------------------------------------------
648
648
649 def _unwrap_exception(self, content):
649 def _unwrap_exception(self, content):
650 """unwrap exception, and remap engine_id to int."""
650 """unwrap exception, and remap engine_id to int."""
651 e = error.unwrap_exception(content)
651 e = error.unwrap_exception(content)
652 # print e.traceback
652 # print e.traceback
653 if e.engine_info:
653 if e.engine_info:
654 e_uuid = e.engine_info['engine_uuid']
654 e_uuid = e.engine_info['engine_uuid']
655 eid = self._engines[e_uuid]
655 eid = self._engines[e_uuid]
656 e.engine_info['engine_id'] = eid
656 e.engine_info['engine_id'] = eid
657 return e
657 return e
658
658
659 def _extract_metadata(self, msg):
659 def _extract_metadata(self, msg):
660 header = msg['header']
660 header = msg['header']
661 parent = msg['parent_header']
661 parent = msg['parent_header']
662 msg_meta = msg['metadata']
662 msg_meta = msg['metadata']
663 content = msg['content']
663 content = msg['content']
664 md = {'msg_id' : parent['msg_id'],
664 md = {'msg_id' : parent['msg_id'],
665 'received' : datetime.now(),
665 'received' : datetime.now(),
666 'engine_uuid' : msg_meta.get('engine', None),
666 'engine_uuid' : msg_meta.get('engine', None),
667 'follow' : msg_meta.get('follow', []),
667 'follow' : msg_meta.get('follow', []),
668 'after' : msg_meta.get('after', []),
668 'after' : msg_meta.get('after', []),
669 'status' : content['status'],
669 'status' : content['status'],
670 }
670 }
671
671
672 if md['engine_uuid'] is not None:
672 if md['engine_uuid'] is not None:
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
674
674
675 if 'date' in parent:
675 if 'date' in parent:
676 md['submitted'] = parent['date']
676 md['submitted'] = parent['date']
677 if 'started' in msg_meta:
677 if 'started' in msg_meta:
678 md['started'] = msg_meta['started']
678 md['started'] = parse_date(msg_meta['started'])
679 if 'date' in header:
679 if 'date' in header:
680 md['completed'] = header['date']
680 md['completed'] = header['date']
681 return md
681 return md
682
682
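# Illustrative sketch (not part of the changeset) of a metadata record
# built above. The change in this diff runs 'started' through
# parse_date(), so it is a datetime even if it arrived as an ISO8601
# string:
#
#     {'msg_id': '...', 'received': datetime(...), 'engine_uuid': '...',
#      'engine_id': 0, 'follow': [], 'after': [], 'status': 'ok',
#      'submitted': ..., 'started': datetime(...), 'completed': ...}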
683 def _register_engine(self, msg):
683 def _register_engine(self, msg):
684 """Register a new engine, and update our connection info."""
684 """Register a new engine, and update our connection info."""
685 content = msg['content']
685 content = msg['content']
686 eid = content['id']
686 eid = content['id']
687 d = {eid : content['uuid']}
687 d = {eid : content['uuid']}
688 self._update_engines(d)
688 self._update_engines(d)
689
689
690 def _unregister_engine(self, msg):
690 def _unregister_engine(self, msg):
691 """Unregister an engine that has died."""
691 """Unregister an engine that has died."""
692 content = msg['content']
692 content = msg['content']
693 eid = int(content['id'])
693 eid = int(content['id'])
694 if eid in self._ids:
694 if eid in self._ids:
695 self._ids.remove(eid)
695 self._ids.remove(eid)
696 uuid = self._engines.pop(eid)
696 uuid = self._engines.pop(eid)
697
697
698 self._handle_stranded_msgs(eid, uuid)
698 self._handle_stranded_msgs(eid, uuid)
699
699
700 if self._task_socket and self._task_scheme == 'pure':
700 if self._task_socket and self._task_scheme == 'pure':
701 self._stop_scheduling_tasks()
701 self._stop_scheduling_tasks()
702
702
703 def _handle_stranded_msgs(self, eid, uuid):
703 def _handle_stranded_msgs(self, eid, uuid):
704 """Handle messages known to be on an engine when the engine unregisters.
704 """Handle messages known to be on an engine when the engine unregisters.
705
705
706 It is possible that this will fire prematurely - that is, an engine will
706 It is possible that this will fire prematurely - that is, an engine will
707 go down after completing a result, and the client will be notified
707 go down after completing a result, and the client will be notified
708 of the unregistration and later receive the successful result.
708 of the unregistration and later receive the successful result.
709 """
709 """
710
710
711 outstanding = self._outstanding_dict[uuid]
711 outstanding = self._outstanding_dict[uuid]
712
712
713 for msg_id in list(outstanding):
713 for msg_id in list(outstanding):
714 if msg_id in self.results:
714 if msg_id in self.results:
715 # we already have the result
715 # we already have the result
716 continue
716 continue
717 try:
717 try:
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
719 except:
719 except:
720 content = error.wrap_exception()
720 content = error.wrap_exception()
721 # build a fake message:
721 # build a fake message:
722 msg = self.session.msg('apply_reply', content=content)
722 msg = self.session.msg('apply_reply', content=content)
723 msg['parent_header']['msg_id'] = msg_id
723 msg['parent_header']['msg_id'] = msg_id
724 msg['metadata']['engine'] = uuid
724 msg['metadata']['engine'] = uuid
725 self._handle_apply_reply(msg)
725 self._handle_apply_reply(msg)
726
726
727 def _handle_execute_reply(self, msg):
727 def _handle_execute_reply(self, msg):
728 """Save the reply to an execute_request into our results.
728 """Save the reply to an execute_request into our results.
729
729
730 execute messages are never actually used. apply is used instead.
730 execute messages are never actually used. apply is used instead.
731 """
731 """
732
732
733 parent = msg['parent_header']
733 parent = msg['parent_header']
734 msg_id = parent['msg_id']
734 msg_id = parent['msg_id']
735 if msg_id not in self.outstanding:
735 if msg_id not in self.outstanding:
736 if msg_id in self.history:
736 if msg_id in self.history:
737 print("got stale result: %s"%msg_id)
737 print("got stale result: %s"%msg_id)
738 else:
738 else:
739 print("got unknown result: %s"%msg_id)
739 print("got unknown result: %s"%msg_id)
740 else:
740 else:
741 self.outstanding.remove(msg_id)
741 self.outstanding.remove(msg_id)
742
742
743 content = msg['content']
743 content = msg['content']
744 header = msg['header']
744 header = msg['header']
745
745
746 # construct metadata:
746 # construct metadata:
747 md = self.metadata[msg_id]
747 md = self.metadata[msg_id]
748 md.update(self._extract_metadata(msg))
748 md.update(self._extract_metadata(msg))
749 # is this redundant?
749 # is this redundant?
750 self.metadata[msg_id] = md
750 self.metadata[msg_id] = md
751
751
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
753 if msg_id in e_outstanding:
753 if msg_id in e_outstanding:
754 e_outstanding.remove(msg_id)
754 e_outstanding.remove(msg_id)
755
755
756 # construct result:
756 # construct result:
757 if content['status'] == 'ok':
757 if content['status'] == 'ok':
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
759 elif content['status'] == 'aborted':
759 elif content['status'] == 'aborted':
760 self.results[msg_id] = error.TaskAborted(msg_id)
760 self.results[msg_id] = error.TaskAborted(msg_id)
761 elif content['status'] == 'resubmitted':
761 elif content['status'] == 'resubmitted':
762 # TODO: handle resubmission
762 # TODO: handle resubmission
763 pass
763 pass
764 else:
764 else:
765 self.results[msg_id] = self._unwrap_exception(content)
765 self.results[msg_id] = self._unwrap_exception(content)
766
766
767 def _handle_apply_reply(self, msg):
767 def _handle_apply_reply(self, msg):
768 """Save the reply to an apply_request into our results."""
768 """Save the reply to an apply_request into our results."""
769 parent = msg['parent_header']
769 parent = msg['parent_header']
770 msg_id = parent['msg_id']
770 msg_id = parent['msg_id']
771 if msg_id not in self.outstanding:
771 if msg_id not in self.outstanding:
772 if msg_id in self.history:
772 if msg_id in self.history:
773 print("got stale result: %s"%msg_id)
773 print("got stale result: %s"%msg_id)
774 print(self.results[msg_id])
774 print(self.results[msg_id])
775 print(msg)
775 print(msg)
776 else:
776 else:
777 print("got unknown result: %s"%msg_id)
777 print("got unknown result: %s"%msg_id)
778 else:
778 else:
779 self.outstanding.remove(msg_id)
779 self.outstanding.remove(msg_id)
780 content = msg['content']
780 content = msg['content']
781 header = msg['header']
781 header = msg['header']
782
782
783 # construct metadata:
783 # construct metadata:
784 md = self.metadata[msg_id]
784 md = self.metadata[msg_id]
785 md.update(self._extract_metadata(msg))
785 md.update(self._extract_metadata(msg))
786 # is this redundant?
786 # is this redundant?
787 self.metadata[msg_id] = md
787 self.metadata[msg_id] = md
788
788
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
790 if msg_id in e_outstanding:
790 if msg_id in e_outstanding:
791 e_outstanding.remove(msg_id)
791 e_outstanding.remove(msg_id)
792
792
793 # construct result:
793 # construct result:
794 if content['status'] == 'ok':
794 if content['status'] == 'ok':
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
796 elif content['status'] == 'aborted':
796 elif content['status'] == 'aborted':
797 self.results[msg_id] = error.TaskAborted(msg_id)
797 self.results[msg_id] = error.TaskAborted(msg_id)
798 elif content['status'] == 'resubmitted':
798 elif content['status'] == 'resubmitted':
799 # TODO: handle resubmission
799 # TODO: handle resubmission
800 pass
800 pass
801 else:
801 else:
802 self.results[msg_id] = self._unwrap_exception(content)
802 self.results[msg_id] = self._unwrap_exception(content)
803
803
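# Editorial outline (not part of the changeset): both reply handlers
# above share one dispatch on content['status']:
#
#     'ok'          -> store ExecuteReply / the unserialized buffers
#     'aborted'     -> store error.TaskAborted(msg_id)
#     'resubmitted' -> TODO, currently a no-op
#     anything else -> store the remapped RemoteError from
#                      _unwrap_exception(content)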
804 def _flush_notifications(self):
804 def _flush_notifications(self):
805 """Flush notifications of engine registrations waiting
805 """Flush notifications of engine registrations waiting
806 in ZMQ queue."""
806 in ZMQ queue."""
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 while msg is not None:
808 while msg is not None:
809 if self.debug:
809 if self.debug:
810 pprint(msg)
810 pprint(msg)
811 msg_type = msg['header']['msg_type']
811 msg_type = msg['header']['msg_type']
812 handler = self._notification_handlers.get(msg_type, None)
812 handler = self._notification_handlers.get(msg_type, None)
813 if handler is None:
813 if handler is None:
814 raise Exception("Unhandled message type: %s" % msg_type)
814 raise Exception("Unhandled message type: %s" % msg_type)
815 else:
815 else:
816 handler(msg)
816 handler(msg)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
818
818
819 def _flush_results(self, sock):
819 def _flush_results(self, sock):
820 """Flush task or queue results waiting in ZMQ queue."""
820 """Flush task or queue results waiting in ZMQ queue."""
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 while msg is not None:
822 while msg is not None:
823 if self.debug:
823 if self.debug:
824 pprint(msg)
824 pprint(msg)
825 msg_type = msg['header']['msg_type']
825 msg_type = msg['header']['msg_type']
826 handler = self._queue_handlers.get(msg_type, None)
826 handler = self._queue_handlers.get(msg_type, None)
827 if handler is None:
827 if handler is None:
828 raise Exception("Unhandled message type: %s" % msg_type)
828 raise Exception("Unhandled message type: %s" % msg_type)
829 else:
829 else:
830 handler(msg)
830 handler(msg)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832
832
833 def _flush_control(self, sock):
833 def _flush_control(self, sock):
834 """Flush replies from the control channel waiting
834 """Flush replies from the control channel waiting
835 in the ZMQ queue.
835 in the ZMQ queue.
836
836
837 Currently: ignore them."""
837 Currently: ignore them."""
838 if self._ignored_control_replies <= 0:
838 if self._ignored_control_replies <= 0:
839 return
839 return
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 while msg is not None:
841 while msg is not None:
842 self._ignored_control_replies -= 1
842 self._ignored_control_replies -= 1
843 if self.debug:
843 if self.debug:
844 pprint(msg)
844 pprint(msg)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846
846
847 def _flush_ignored_control(self):
847 def _flush_ignored_control(self):
848 """flush ignored control replies"""
848 """flush ignored control replies"""
849 while self._ignored_control_replies > 0:
849 while self._ignored_control_replies > 0:
850 self.session.recv(self._control_socket)
850 self.session.recv(self._control_socket)
851 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
852
852
853 def _flush_ignored_hub_replies(self):
853 def _flush_ignored_hub_replies(self):
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
855 while msg is not None:
855 while msg is not None:
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
857
857
858 def _flush_iopub(self, sock):
858 def _flush_iopub(self, sock):
859 """Flush replies from the iopub channel waiting
859 """Flush replies from the iopub channel waiting
860 in the ZMQ queue.
860 in the ZMQ queue.
861 """
861 """
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
863 while msg is not None:
863 while msg is not None:
864 if self.debug:
864 if self.debug:
865 pprint(msg)
865 pprint(msg)
866 parent = msg['parent_header']
866 parent = msg['parent_header']
867 # ignore IOPub messages with no parent.
867 # ignore IOPub messages with no parent.
868 # Caused by print statements or warnings from before the first execution.
868 # Caused by print statements or warnings from before the first execution.
869 if not parent:
869 if not parent:
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 continue
871 continue
872 msg_id = parent['msg_id']
872 msg_id = parent['msg_id']
873 content = msg['content']
873 content = msg['content']
874 header = msg['header']
874 header = msg['header']
875 msg_type = msg['header']['msg_type']
875 msg_type = msg['header']['msg_type']
876
876
877 # init metadata:
877 # init metadata:
878 md = self.metadata[msg_id]
878 md = self.metadata[msg_id]
879
879
880 if msg_type == 'stream':
880 if msg_type == 'stream':
881 name = content['name']
881 name = content['name']
882 s = md[name] or ''
882 s = md[name] or ''
883 md[name] = s + content['data']
883 md[name] = s + content['data']
884 elif msg_type == 'pyerr':
884 elif msg_type == 'pyerr':
885 md.update({'pyerr' : self._unwrap_exception(content)})
885 md.update({'pyerr' : self._unwrap_exception(content)})
886 elif msg_type == 'pyin':
886 elif msg_type == 'pyin':
887 md.update({'pyin' : content['code']})
887 md.update({'pyin' : content['code']})
888 elif msg_type == 'display_data':
888 elif msg_type == 'display_data':
889 md['outputs'].append(content)
889 md['outputs'].append(content)
890 elif msg_type == 'pyout':
890 elif msg_type == 'pyout':
891 md['pyout'] = content
891 md['pyout'] = content
892 elif msg_type == 'data_message':
892 elif msg_type == 'data_message':
893 data, remainder = serialize.unserialize_object(msg['buffers'])
893 data, remainder = serialize.unserialize_object(msg['buffers'])
894 md['data'].update(data)
894 md['data'].update(data)
895 elif msg_type == 'status':
895 elif msg_type == 'status':
896 # idle message comes after all outputs
896 # idle message comes after all outputs
897 if content['execution_state'] == 'idle':
897 if content['execution_state'] == 'idle':
898 md['outputs_ready'] = True
898 md['outputs_ready'] = True
899 else:
899 else:
900 # unhandled msg_type
900 # unhandled msg_type
901 pass
901 pass
902
902
903 # redundant?
903 # redundant?
904 self.metadata[msg_id] = md
904 self.metadata[msg_id] = md
905
905
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
907
907
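# Illustrative sketch (not part of the changeset): the effect of
# _flush_iopub shows up in the per-task metadata. Assuming `rc` is a
# connected Client and msg_id names a finished task that printed:
#
#     md = rc.metadata[msg_id]
#     md['stdout']         # concatenated 'stream' data for that name
#     md['outputs']        # accumulated display_data content dicts
#     md['pyout']          # the pyout content, if any
#     md['outputs_ready']  # True once the engine reported 'idle'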
908 #--------------------------------------------------------------------------
908 #--------------------------------------------------------------------------
909 # len, getitem
909 # len, getitem
910 #--------------------------------------------------------------------------
910 #--------------------------------------------------------------------------
911
911
912 def __len__(self):
912 def __len__(self):
913 """len(client) returns # of engines."""
913 """len(client) returns # of engines."""
914 return len(self.ids)
914 return len(self.ids)
915
915
916 def __getitem__(self, key):
916 def __getitem__(self, key):
917 """index access returns DirectView multiplexer objects
917 """index access returns DirectView multiplexer objects
918
918
919 Must be int, slice, or list/tuple/xrange of ints"""
919 Must be int, slice, or list/tuple/xrange of ints"""
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
922 else:
922 else:
923 return self.direct_view(key)
923 return self.direct_view(key)
924
924
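# Usage sketch (not part of the changeset), assuming a connected client
# `rc`:
#
#     rc[0]       # DirectView on engine 0
#     rc[::2]     # DirectView on every other engine
#     rc[[1, 3]]  # DirectView on engines 1 and 3
#
# Each form is simply forwarded to self.direct_view(key).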
925 #--------------------------------------------------------------------------
925 #--------------------------------------------------------------------------
926 # Begin public methods
926 # Begin public methods
927 #--------------------------------------------------------------------------
927 #--------------------------------------------------------------------------
928
928
929 @property
929 @property
930 def ids(self):
930 def ids(self):
931 """Always up-to-date ids property."""
931 """Always up-to-date ids property."""
932 self._flush_notifications()
932 self._flush_notifications()
933 # always copy:
933 # always copy:
934 return list(self._ids)
934 return list(self._ids)
935
935
936 def activate(self, targets='all', suffix=''):
936 def activate(self, targets='all', suffix=''):
937 """Create a DirectView and register it with IPython magics
937 """Create a DirectView and register it with IPython magics
938
938
939 Defines the magics `%px, %autopx, %pxresult, %%px`
939 Defines the magics `%px, %autopx, %pxresult, %%px`
940
940
941 Parameters
941 Parameters
942 ----------
942 ----------
943
943
944 targets: int, list of ints, or 'all'
944 targets: int, list of ints, or 'all'
945 The engines on which the view's magics will run
945 The engines on which the view's magics will run
946 suffix: str [default: '']
946 suffix: str [default: '']
947 The suffix, if any, for the magics. This allows you to have
947 The suffix, if any, for the magics. This allows you to have
948 multiple views associated with parallel magics at the same time.
948 multiple views associated with parallel magics at the same time.
949
949
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
952 on engine 0.
952 on engine 0.
953 """
953 """
954 view = self.direct_view(targets)
954 view = self.direct_view(targets)
955 view.block = True
955 view.block = True
956 view.activate(suffix)
956 view.activate(suffix)
957 return view
957 return view
958
958
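# Usage sketch (not part of the changeset), following the docstring
# above; `rc` is a connected client:
#
#     In [1]: rc.activate()                 # %px etc. target all engines
#     In [2]: %px import os
#     In [3]: rc.activate(targets=0, suffix='0')
#     In [4]: %px0 print(os.getpid())       # runs on engine 0 only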
959 def close(self, linger=None):
959 def close(self, linger=None):
960 """Close my zmq Sockets
960 """Close my zmq Sockets
961
961
962 If `linger`, set the zmq LINGER socket option,
962 If `linger`, set the zmq LINGER socket option,
963 which allows discarding of messages.
963 which allows discarding of messages.
964 """
964 """
965 if self._closed:
965 if self._closed:
966 return
966 return
967 self.stop_spin_thread()
967 self.stop_spin_thread()
968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
969 for name in snames:
969 for name in snames:
970 socket = getattr(self, name)
970 socket = getattr(self, name)
971 if socket is not None and not socket.closed:
971 if socket is not None and not socket.closed:
972 if linger is not None:
972 if linger is not None:
973 socket.close(linger=linger)
973 socket.close(linger=linger)
974 else:
974 else:
975 socket.close()
975 socket.close()
976 self._closed = True
976 self._closed = True
977
977
978 def _spin_every(self, interval=1):
978 def _spin_every(self, interval=1):
979 """target func for use in spin_thread"""
979 """target func for use in spin_thread"""
980 while True:
980 while True:
981 if self._stop_spinning.is_set():
981 if self._stop_spinning.is_set():
982 return
982 return
983 time.sleep(interval)
983 time.sleep(interval)
984 self.spin()
984 self.spin()
985
985
986 def spin_thread(self, interval=1):
986 def spin_thread(self, interval=1):
987 """call Client.spin() in a background thread on some regular interval
987 """call Client.spin() in a background thread on some regular interval
988
988
989 This helps ensure that messages don't pile up too much in the zmq queue
989 This helps ensure that messages don't pile up too much in the zmq queue
990 while you are working on other things, or just leaving an idle terminal.
990 while you are working on other things, or just leaving an idle terminal.
991
991
992 It also helps limit potential padding of the `received` timestamp
992 It also helps limit potential padding of the `received` timestamp
993 on AsyncResult objects, used for timings.
993 on AsyncResult objects, used for timings.
994
994
995 Parameters
995 Parameters
996 ----------
996 ----------
997
997
998 interval : float, optional
998 interval : float, optional
999 The interval on which to spin the client in the background thread
999 The interval on which to spin the client in the background thread
1000 (simply passed to time.sleep).
1000 (simply passed to time.sleep).
1001
1001
1002 Notes
1002 Notes
1003 -----
1003 -----
1004
1004
1005 For precision timing, you may want to use this method to put a bound
1005 For precision timing, you may want to use this method to put a bound
1006 on the jitter (in seconds) in `received` timestamps used
1006 on the jitter (in seconds) in `received` timestamps used
1007 in AsyncResult.wall_time.
1007 in AsyncResult.wall_time.
1008
1008
1009 """
1009 """
1010 if self._spin_thread is not None:
1010 if self._spin_thread is not None:
1011 self.stop_spin_thread()
1011 self.stop_spin_thread()
1012 self._stop_spinning.clear()
1012 self._stop_spinning.clear()
1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1014 self._spin_thread.daemon = True
1014 self._spin_thread.daemon = True
1015 self._spin_thread.start()
1015 self._spin_thread.start()
1016
1016
1017 def stop_spin_thread(self):
1017 def stop_spin_thread(self):
1018 """stop background spin_thread, if any"""
1018 """stop background spin_thread, if any"""
1019 if self._spin_thread is not None:
1019 if self._spin_thread is not None:
1020 self._stop_spinning.set()
1020 self._stop_spinning.set()
1021 self._spin_thread.join()
1021 self._spin_thread.join()
1022 self._spin_thread = None
1022 self._spin_thread = None
1023
1023
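# Usage sketch (not part of the changeset): bounding the jitter of
# `received` timestamps with the background spin thread; `rc` is a
# connected client and some_fn a placeholder callable:
#
#     rc.spin_thread(interval=0.05)    # flush queues every 50 ms
#     ar = rc[:].apply_async(some_fn)
#     ar.wait()
#     print(ar.wall_time)              # received-jitter now bounded
#     rc.stop_spin_thread()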
1024 def spin(self):
1024 def spin(self):
1025 """Flush any registration notifications and execution results
1025 """Flush any registration notifications and execution results
1026 waiting in the ZMQ queue.
1026 waiting in the ZMQ queue.
1027 """
1027 """
1028 if self._notification_socket:
1028 if self._notification_socket:
1029 self._flush_notifications()
1029 self._flush_notifications()
1030 if self._iopub_socket:
1030 if self._iopub_socket:
1031 self._flush_iopub(self._iopub_socket)
1031 self._flush_iopub(self._iopub_socket)
1032 if self._mux_socket:
1032 if self._mux_socket:
1033 self._flush_results(self._mux_socket)
1033 self._flush_results(self._mux_socket)
1034 if self._task_socket:
1034 if self._task_socket:
1035 self._flush_results(self._task_socket)
1035 self._flush_results(self._task_socket)
1036 if self._control_socket:
1036 if self._control_socket:
1037 self._flush_control(self._control_socket)
1037 self._flush_control(self._control_socket)
1038 if self._query_socket:
1038 if self._query_socket:
1039 self._flush_ignored_hub_replies()
1039 self._flush_ignored_hub_replies()
1040
1040
1041 def wait(self, jobs=None, timeout=-1):
1041 def wait(self, jobs=None, timeout=-1):
1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1043
1043
1044 Parameters
1044 Parameters
1045 ----------
1045 ----------
1046
1046
1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1048 ints are indices to self.history
1048 ints are indices to self.history
1049 strs are msg_ids
1049 strs are msg_ids
1050 default: wait on all outstanding messages
1050 default: wait on all outstanding messages
1051 timeout : float
1051 timeout : float
1052 a time in seconds, after which to give up.
1052 a time in seconds, after which to give up.
1053 default is -1, which means no timeout
1053 default is -1, which means no timeout
1054
1054
1055 Returns
1055 Returns
1056 -------
1056 -------
1057
1057
1058 True : when all msg_ids are done
1058 True : when all msg_ids are done
1059 False : timeout reached, some msg_ids still outstanding
1059 False : timeout reached, some msg_ids still outstanding
1060 """
1060 """
1061 tic = time.time()
1061 tic = time.time()
1062 if jobs is None:
1062 if jobs is None:
1063 theids = self.outstanding
1063 theids = self.outstanding
1064 else:
1064 else:
1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1066 jobs = [jobs]
1066 jobs = [jobs]
1067 theids = set()
1067 theids = set()
1068 for job in jobs:
1068 for job in jobs:
1069 if isinstance(job, int):
1069 if isinstance(job, int):
1070 # index access
1070 # index access
1071 job = self.history[job]
1071 job = self.history[job]
1072 elif isinstance(job, AsyncResult):
1072 elif isinstance(job, AsyncResult):
1073 theids.update(job.msg_ids)
1073 theids.update(job.msg_ids)
1074 continue
1074 continue
1075 theids.add(job)
1075 theids.add(job)
1076 if not theids.intersection(self.outstanding):
1076 if not theids.intersection(self.outstanding):
1077 return True
1077 return True
1078 self.spin()
1078 self.spin()
1079 while theids.intersection(self.outstanding):
1079 while theids.intersection(self.outstanding):
1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1081 break
1081 break
1082 time.sleep(1e-3)
1082 time.sleep(1e-3)
1083 self.spin()
1083 self.spin()
1084 return len(theids.intersection(self.outstanding)) == 0
1084 return len(theids.intersection(self.outstanding)) == 0
1085
1085
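# Usage sketch (not part of the changeset); `rc` is a connected client
# and some_fn a placeholder callable:
#
#     ar = rc[:].apply_async(some_fn)
#     if rc.wait(ar, timeout=5):       # True: all msg_ids finished
#         results = ar.get()
#     else:                            # False: timed out
#         print(len(rc.outstanding), 'msg_ids still outstanding')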
1086 #--------------------------------------------------------------------------
1086 #--------------------------------------------------------------------------
1087 # Control methods
1087 # Control methods
1088 #--------------------------------------------------------------------------
1088 #--------------------------------------------------------------------------
1089
1089
1090 @spin_first
1090 @spin_first
1091 def clear(self, targets=None, block=None):
1091 def clear(self, targets=None, block=None):
1092 """Clear the namespace in target(s)."""
1092 """Clear the namespace in target(s)."""
1093 block = self.block if block is None else block
1093 block = self.block if block is None else block
1094 targets = self._build_targets(targets)[0]
1094 targets = self._build_targets(targets)[0]
1095 for t in targets:
1095 for t in targets:
1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1097 error = False
1097 error = False
1098 if block:
1098 if block:
1099 self._flush_ignored_control()
1099 self._flush_ignored_control()
1100 for i in range(len(targets)):
1100 for i in range(len(targets)):
1101 idents,msg = self.session.recv(self._control_socket,0)
1101 idents,msg = self.session.recv(self._control_socket,0)
1102 if self.debug:
1102 if self.debug:
1103 pprint(msg)
1103 pprint(msg)
1104 if msg['content']['status'] != 'ok':
1104 if msg['content']['status'] != 'ok':
1105 error = self._unwrap_exception(msg['content'])
1105 error = self._unwrap_exception(msg['content'])
1106 else:
1106 else:
1107 self._ignored_control_replies += len(targets)
1107 self._ignored_control_replies += len(targets)
1108 if error:
1108 if error:
1109 raise error
1109 raise error
1110
1110
1111
1111
1112 @spin_first
1112 @spin_first
1113 def abort(self, jobs=None, targets=None, block=None):
1113 def abort(self, jobs=None, targets=None, block=None):
1114 """Abort specific jobs from the execution queues of target(s).
1114 """Abort specific jobs from the execution queues of target(s).
1115
1115
1116 This is a mechanism to prevent jobs that have already been submitted
1116 This is a mechanism to prevent jobs that have already been submitted
1117 from executing.
1117 from executing.
1118
1118
1119 Parameters
1119 Parameters
1120 ----------
1120 ----------
1121
1121
1122 jobs : msg_id, list of msg_ids, or AsyncResult
1122 jobs : msg_id, list of msg_ids, or AsyncResult
1123 The jobs to be aborted
1123 The jobs to be aborted
1124
1124
1125 If unspecified/None: abort all outstanding jobs.
1125 If unspecified/None: abort all outstanding jobs.
1126
1126
1127 """
1127 """
1128 block = self.block if block is None else block
1128 block = self.block if block is None else block
1129 jobs = jobs if jobs is not None else list(self.outstanding)
1129 jobs = jobs if jobs is not None else list(self.outstanding)
1130 targets = self._build_targets(targets)[0]
1130 targets = self._build_targets(targets)[0]
1131
1131
1132 msg_ids = []
1132 msg_ids = []
1133 if isinstance(jobs, string_types + (AsyncResult,)):
1133 if isinstance(jobs, string_types + (AsyncResult,)):
1134 jobs = [jobs]
1134 jobs = [jobs]
1135 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1135 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1136 if bad_ids:
1136 if bad_ids:
1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1138 for j in jobs:
1138 for j in jobs:
1139 if isinstance(j, AsyncResult):
1139 if isinstance(j, AsyncResult):
1140 msg_ids.extend(j.msg_ids)
1140 msg_ids.extend(j.msg_ids)
1141 else:
1141 else:
1142 msg_ids.append(j)
1142 msg_ids.append(j)
1143 content = dict(msg_ids=msg_ids)
1143 content = dict(msg_ids=msg_ids)
1144 for t in targets:
1144 for t in targets:
1145 self.session.send(self._control_socket, 'abort_request',
1145 self.session.send(self._control_socket, 'abort_request',
1146 content=content, ident=t)
1146 content=content, ident=t)
1147 error = False
1147 error = False
1148 if block:
1148 if block:
1149 self._flush_ignored_control()
1149 self._flush_ignored_control()
1150 for i in range(len(targets)):
1150 for i in range(len(targets)):
1151 idents,msg = self.session.recv(self._control_socket,0)
1151 idents,msg = self.session.recv(self._control_socket,0)
1152 if self.debug:
1152 if self.debug:
1153 pprint(msg)
1153 pprint(msg)
1154 if msg['content']['status'] != 'ok':
1154 if msg['content']['status'] != 'ok':
1155 error = self._unwrap_exception(msg['content'])
1155 error = self._unwrap_exception(msg['content'])
1156 else:
1156 else:
1157 self._ignored_control_replies += len(targets)
1157 self._ignored_control_replies += len(targets)
1158 if error:
1158 if error:
1159 raise error
1159 raise error
1160
1160
1161 @spin_first
1161 @spin_first
1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1163 """Terminates one or more engine processes, optionally including the hub.
1163 """Terminates one or more engine processes, optionally including the hub.
1164
1164
1165 Parameters
1165 Parameters
1166 ----------
1166 ----------
1167
1167
1168 targets: list of ints or 'all' [default: all]
1168 targets: list of ints or 'all' [default: all]
1169 Which engines to shutdown.
1169 Which engines to shutdown.
1170 hub: bool [default: False]
1170 hub: bool [default: False]
1171 Whether to include the Hub. hub=True implies targets='all'.
1171 Whether to include the Hub. hub=True implies targets='all'.
1172 block: bool [default: self.block]
1172 block: bool [default: self.block]
1173 Whether to wait for clean shutdown replies or not.
1173 Whether to wait for clean shutdown replies or not.
1174 restart: bool [default: False]
1174 restart: bool [default: False]
1175 NOT IMPLEMENTED
1175 NOT IMPLEMENTED
1176 whether to restart engines after shutting them down.
1176 whether to restart engines after shutting them down.
1177 """
1177 """
1178 from IPython.parallel.error import NoEnginesRegistered
1178 from IPython.parallel.error import NoEnginesRegistered
1179 if restart:
1179 if restart:
1180 raise NotImplementedError("Engine restart is not yet implemented")
1180 raise NotImplementedError("Engine restart is not yet implemented")
1181
1181
1182 block = self.block if block is None else block
1182 block = self.block if block is None else block
1183 if hub:
1183 if hub:
1184 targets = 'all'
1184 targets = 'all'
1185 try:
1185 try:
1186 targets = self._build_targets(targets)[0]
1186 targets = self._build_targets(targets)[0]
1187 except NoEnginesRegistered:
1187 except NoEnginesRegistered:
1188 targets = []
1188 targets = []
1189 for t in targets:
1189 for t in targets:
1190 self.session.send(self._control_socket, 'shutdown_request',
1190 self.session.send(self._control_socket, 'shutdown_request',
1191 content={'restart':restart},ident=t)
1191 content={'restart':restart},ident=t)
1192 error = False
1192 error = False
1193 if block or hub:
1193 if block or hub:
1194 self._flush_ignored_control()
1194 self._flush_ignored_control()
1195 for i in range(len(targets)):
1195 for i in range(len(targets)):
1196 idents,msg = self.session.recv(self._control_socket, 0)
1196 idents,msg = self.session.recv(self._control_socket, 0)
1197 if self.debug:
1197 if self.debug:
1198 pprint(msg)
1198 pprint(msg)
1199 if msg['content']['status'] != 'ok':
1199 if msg['content']['status'] != 'ok':
1200 error = self._unwrap_exception(msg['content'])
1200 error = self._unwrap_exception(msg['content'])
1201 else:
1201 else:
1202 self._ignored_control_replies += len(targets)
1202 self._ignored_control_replies += len(targets)
1203
1203
1204 if hub:
1204 if hub:
1205 time.sleep(0.25)
1205 time.sleep(0.25)
1206 self.session.send(self._query_socket, 'shutdown_request')
1206 self.session.send(self._query_socket, 'shutdown_request')
1207 idents,msg = self.session.recv(self._query_socket, 0)
1207 idents,msg = self.session.recv(self._query_socket, 0)
1208 if self.debug:
1208 if self.debug:
1209 pprint(msg)
1209 pprint(msg)
1210 if msg['content']['status'] != 'ok':
1210 if msg['content']['status'] != 'ok':
1211 error = self._unwrap_exception(msg['content'])
1211 error = self._unwrap_exception(msg['content'])
1212
1212
1213 if error:
1213 if error:
1214 raise error
1214 raise error
1215
1215
1216 #--------------------------------------------------------------------------
1216 #--------------------------------------------------------------------------
1217 # Execution related methods
1217 # Execution related methods
1218 #--------------------------------------------------------------------------
1218 #--------------------------------------------------------------------------
1219
1219
1220 def _maybe_raise(self, result):
1220 def _maybe_raise(self, result):
1221 """wrapper for maybe raising an exception if apply failed."""
1221 """wrapper for maybe raising an exception if apply failed."""
1222 if isinstance(result, error.RemoteError):
1222 if isinstance(result, error.RemoteError):
1223 raise result
1223 raise result
1224
1224
1225 return result
1225 return result
1226
1226
1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1228 ident=None):
1228 ident=None):
1229 """construct and send an apply message via a socket.
1229 """construct and send an apply message via a socket.
1230
1230
1231 This is the principal method with which all engine execution is performed by views.
1231 This is the principal method with which all engine execution is performed by views.
1232 """
1232 """
1233
1233
1234 if self._closed:
1234 if self._closed:
1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1236
1236
1237 # defaults:
1237 # defaults:
1238 args = args if args is not None else []
1238 args = args if args is not None else []
1239 kwargs = kwargs if kwargs is not None else {}
1239 kwargs = kwargs if kwargs is not None else {}
1240 metadata = metadata if metadata is not None else {}
1240 metadata = metadata if metadata is not None else {}
1241
1241
1242 # validate arguments
1242 # validate arguments
1243 if not callable(f) and not isinstance(f, Reference):
1243 if not callable(f) and not isinstance(f, Reference):
1244 raise TypeError("f must be callable, not %s"%type(f))
1244 raise TypeError("f must be callable, not %s"%type(f))
1245 if not isinstance(args, (tuple, list)):
1245 if not isinstance(args, (tuple, list)):
1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1247 if not isinstance(kwargs, dict):
1247 if not isinstance(kwargs, dict):
1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1249 if not isinstance(metadata, dict):
1249 if not isinstance(metadata, dict):
1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1251
1251
1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1253 buffer_threshold=self.session.buffer_threshold,
1253 buffer_threshold=self.session.buffer_threshold,
1254 item_threshold=self.session.item_threshold,
1254 item_threshold=self.session.item_threshold,
1255 )
1255 )
1256
1256
1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1258 metadata=metadata, track=track)
1258 metadata=metadata, track=track)
1259
1259
1260 msg_id = msg['header']['msg_id']
1260 msg_id = msg['header']['msg_id']
1261 self.outstanding.add(msg_id)
1261 self.outstanding.add(msg_id)
1262 if ident:
1262 if ident:
1263 # possibly routed to a specific engine
1263 # possibly routed to a specific engine
1264 if isinstance(ident, list):
1264 if isinstance(ident, list):
1265 ident = ident[-1]
1265 ident = ident[-1]
1266 if ident in self._engines.values():
1266 if ident in self._engines.values():
1267 # save for later, in case of engine death
1267 # save for later, in case of engine death
1268 self._outstanding_dict[ident].add(msg_id)
1268 self._outstanding_dict[ident].add(msg_id)
1269 self.history.append(msg_id)
1269 self.history.append(msg_id)
1270 self.metadata[msg_id]['submitted'] = datetime.now()
1270 self.metadata[msg_id]['submitted'] = datetime.now()
1271
1271
1272 return msg
1272 return msg
1273
1273
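# Editorial outline (not part of the changeset): the life of an apply
# request as implemented above:
#
#     1. pack f/args/kwargs into buffers with serialize.pack_apply_message
#     2. session.send(socket, 'apply_request', buffers=bufs, ...)
#     3. record msg_id in self.outstanding and self.history
#     4. if routed to a single engine, also record it in _outstanding_dict
#        so _handle_stranded_msgs can fail it if that engine dies
#     5. stamp metadata[msg_id]['submitted'] with datetime.now()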
1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1275 """construct and send an execute request via a socket.
1275 """construct and send an execute request via a socket.
1276
1276
1277 """
1277 """
1278
1278
1279 if self._closed:
1279 if self._closed:
1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1281
1281
1282 # defaults:
1282 # defaults:
1283 metadata = metadata if metadata is not None else {}
1283 metadata = metadata if metadata is not None else {}
1284
1284
1285 # validate arguments
1285 # validate arguments
1286 if not isinstance(code, string_types):
1286 if not isinstance(code, string_types):
1287 raise TypeError("code must be text, not %s" % type(code))
1287 raise TypeError("code must be text, not %s" % type(code))
1288 if not isinstance(metadata, dict):
1288 if not isinstance(metadata, dict):
1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1290
1290
1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1292
1292
1293
1293
1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1295 metadata=metadata)
1295 metadata=metadata)
1296
1296
1297 msg_id = msg['header']['msg_id']
1297 msg_id = msg['header']['msg_id']
1298 self.outstanding.add(msg_id)
1298 self.outstanding.add(msg_id)
1299 if ident:
1299 if ident:
1300 # possibly routed to a specific engine
1300 # possibly routed to a specific engine
1301 if isinstance(ident, list):
1301 if isinstance(ident, list):
1302 ident = ident[-1]
1302 ident = ident[-1]
1303 if ident in self._engines.values():
1303 if ident in self._engines.values():
1304 # save for later, in case of engine death
1304 # save for later, in case of engine death
1305 self._outstanding_dict[ident].add(msg_id)
1305 self._outstanding_dict[ident].add(msg_id)
1306 self.history.append(msg_id)
1306 self.history.append(msg_id)
1307 self.metadata[msg_id]['submitted'] = datetime.now()
1307 self.metadata[msg_id]['submitted'] = datetime.now()
1308
1308
1309 return msg
1309 return msg
1310
1310
1311 #--------------------------------------------------------------------------
1311 #--------------------------------------------------------------------------
1312 # construct a View object
1312 # construct a View object
1313 #--------------------------------------------------------------------------
1313 #--------------------------------------------------------------------------
1314
1314
1315 def load_balanced_view(self, targets=None):
1315 def load_balanced_view(self, targets=None):
1316 """construct a DirectView object.
1316 """construct a DirectView object.
1317
1317
1318 If no arguments are specified, create a LoadBalancedView
1318 If no arguments are specified, create a LoadBalancedView
1319 using all engines.
1319 using all engines.
1320
1320
1321 Parameters
1321 Parameters
1322 ----------
1322 ----------
1323
1323
1324 targets: list,slice,int,etc. [default: use all engines]
1324 targets: list,slice,int,etc. [default: use all engines]
1325 The subset of engines across which to load-balance
1325 The subset of engines across which to load-balance
1326 """
1326 """
1327 if targets == 'all':
1327 if targets == 'all':
1328 targets = None
1328 targets = None
1329 if targets is not None:
1329 if targets is not None:
1330 targets = self._build_targets(targets)[1]
1330 targets = self._build_targets(targets)[1]
1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1332
1332
1333 def direct_view(self, targets='all'):
1333 def direct_view(self, targets='all'):
1334 """construct a DirectView object.
1334 """construct a DirectView object.
1335
1335
1336 If no targets are specified, create a DirectView using all engines.
1336 If no targets are specified, create a DirectView using all engines.
1337
1337
1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1340 all *current* engines, and that list will not change.
1340 all *current* engines, and that list will not change.
1341
1341
1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1343 engines added after the DirectView is constructed.
1343 engines added after the DirectView is constructed.
1344
1344
1345 Parameters
1345 Parameters
1346 ----------
1346 ----------
1347
1347
1348 targets: list,slice,int,etc. [default: use all engines]
1348 targets: list,slice,int,etc. [default: use all engines]
1349 The engines to use for the View
1349 The engines to use for the View
1350 """
1350 """
1351 single = isinstance(targets, int)
1351 single = isinstance(targets, int)
1352 # allow 'all' to be lazily evaluated at each execution
1352 # allow 'all' to be lazily evaluated at each execution
1353 if targets != 'all':
1353 if targets != 'all':
1354 targets = self._build_targets(targets)[1]
1354 targets = self._build_targets(targets)[1]
1355 if single:
1355 if single:
1356 targets = targets[0]
1356 targets = targets[0]
1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1358
1358
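# Usage sketch (not part of the changeset) of the lazy-'all' distinction
# described above, assuming a connected client `rc`:
#
#     v_live = rc.direct_view('all')   # re-resolves engines at each run
#     v_snap = rc[:]                   # fixed to the engines present now
#
# After a new engine registers, v_live picks it up automatically, while
# v_snap keeps its original target list.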
1359 #--------------------------------------------------------------------------
1359 #--------------------------------------------------------------------------
1360 # Query methods
1360 # Query methods
1361 #--------------------------------------------------------------------------
1361 #--------------------------------------------------------------------------
1362
1362
1363 @spin_first
1363 @spin_first
1364 def get_result(self, indices_or_msg_ids=None, block=None):
1364 def get_result(self, indices_or_msg_ids=None, block=None):
1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1366
1366
1367 If the client already has the results, no request to the Hub will be made.
1367 If the client already has the results, no request to the Hub will be made.
1368
1368
1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1370 that include metadata about execution, and allow for awaiting results that
1370 that include metadata about execution, and allow for awaiting results that
1371 were not submitted by this Client.
1371 were not submitted by this Client.
1372
1372
1373 It can also be a convenient way to retrieve the metadata associated with
1373 It can also be a convenient way to retrieve the metadata associated with
1374 blocking execution, since it always retrieves the result's metadata as well.
1374 blocking execution, since it always retrieves the result's metadata as well.
1375
1375
1376 Examples
1376 Examples
1377 --------
1377 --------
1378 ::
1378 ::
1379
1379
1380 In [10]: ar = client.get_result(msg_id)
1380 In [10]: ar = client.get_result(msg_id)
1381
1381
1382 Parameters
1382 Parameters
1383 ----------
1383 ----------
1384
1384
1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1386 The history indices or msg_ids of the results to be retrieved
1386 The history indices or msg_ids of the results to be retrieved
1387
1387
1388 block : bool
1388 block : bool
1389 Whether to wait for the result to be done
1389 Whether to wait for the result to be done
1390
1390
1391 Returns
1391 Returns
1392 -------
1392 -------
1393
1393
1394 AsyncResult
1394 AsyncResult
1395 A single AsyncResult object will always be returned.
1395 A single AsyncResult object will always be returned.
1396
1396
1397 AsyncHubResult
1397 AsyncHubResult
1398 A subclass of AsyncResult that retrieves results from the Hub
1398 A subclass of AsyncResult that retrieves results from the Hub
1399
1399
1400 """
1400 """
1401 block = self.block if block is None else block
1401 block = self.block if block is None else block
1402 if indices_or_msg_ids is None:
1402 if indices_or_msg_ids is None:
1403 indices_or_msg_ids = -1
1403 indices_or_msg_ids = -1
1404
1404
1405 single_result = False
1405 single_result = False
1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1407 indices_or_msg_ids = [indices_or_msg_ids]
1407 indices_or_msg_ids = [indices_or_msg_ids]
1408 single_result = True
1408 single_result = True
1409
1409
1410 theids = []
1410 theids = []
1411 for id in indices_or_msg_ids:
1411 for id in indices_or_msg_ids:
1412 if isinstance(id, int):
1412 if isinstance(id, int):
1413 id = self.history[id]
1413 id = self.history[id]
1414 if not isinstance(id, string_types):
1414 if not isinstance(id, string_types):
1415 raise TypeError("indices must be str or int, not %r"%id)
1415 raise TypeError("indices must be str or int, not %r"%id)
1416 theids.append(id)
1416 theids.append(id)
1417
1417
1418 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1418 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1419 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1419 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1420
1420
1421 # given a single msg_id initially, get_result should get the result itself,
1421 # given a single msg_id initially, get_result should get the result itself,
1422 # not a length-one list
1422 # not a length-one list
1423 if single_result:
1423 if single_result:
1424 theids = theids[0]
1424 theids = theids[0]
1425
1425
1426 if remote_ids:
1426 if remote_ids:
1427 ar = AsyncHubResult(self, msg_ids=theids)
1427 ar = AsyncHubResult(self, msg_ids=theids)
1428 else:
1428 else:
1429 ar = AsyncResult(self, msg_ids=theids)
1429 ar = AsyncResult(self, msg_ids=theids)
1430
1430
1431 if block:
1431 if block:
1432 ar.wait()
1432 ar.wait()
1433
1433
1434 return ar
1434 return ar
1435
1435
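# Usage sketch (not part of the changeset); `rc` is a connected client
# and msg_id / msg_a / msg_b are placeholder ids from rc.history:
#
#     ar = rc.get_result(-1)               # most recent submission
#     ar = rc.get_result(msg_id)           # single id -> single result
#     ar = rc.get_result([msg_a, msg_b])   # list in, combined result out
#
# Ids not known locally yield an AsyncHubResult, which fetches the
# missing records from the Hub.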
1436 @spin_first
1436 @spin_first
1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1438 """Resubmit one or more tasks.
1438 """Resubmit one or more tasks.
1439
1439
1440 in-flight tasks may not be resubmitted.
1440 in-flight tasks may not be resubmitted.
1441
1441
1442 Parameters
1442 Parameters
1443 ----------
1443 ----------
1444
1444
1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1446 The history indices or msg_ids of the tasks to be resubmitted
1446 The history indices or msg_ids of the tasks to be resubmitted
1447
1447
1448 block : bool
1448 block : bool
1449 Whether to wait for the result to be done
1449 Whether to wait for the result to be done
1450
1450
1451 Returns
1451 Returns
1452 -------
1452 -------
1453
1453
1454 AsyncHubResult
1454 AsyncHubResult
1455 A subclass of AsyncResult that retrieves results from the Hub
1455 A subclass of AsyncResult that retrieves results from the Hub
1456
1456
1457 """
1457 """
1458 block = self.block if block is None else block
1458 block = self.block if block is None else block
1459 if indices_or_msg_ids is None:
1459 if indices_or_msg_ids is None:
1460 indices_or_msg_ids = -1
1460 indices_or_msg_ids = -1
1461
1461
1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1463 indices_or_msg_ids = [indices_or_msg_ids]
1463 indices_or_msg_ids = [indices_or_msg_ids]
1464
1464
1465 theids = []
1465 theids = []
1466 for id in indices_or_msg_ids:
1466 for id in indices_or_msg_ids:
1467 if isinstance(id, int):
1467 if isinstance(id, int):
1468 id = self.history[id]
1468 id = self.history[id]
1469 if not isinstance(id, string_types):
1469 if not isinstance(id, string_types):
1470 raise TypeError("indices must be str or int, not %r"%id)
1470 raise TypeError("indices must be str or int, not %r"%id)
1471 theids.append(id)
1471 theids.append(id)
1472
1472
1473 content = dict(msg_ids = theids)
1473 content = dict(msg_ids = theids)
1474
1474
1475 self.session.send(self._query_socket, 'resubmit_request', content)
1475 self.session.send(self._query_socket, 'resubmit_request', content)
1476
1476
1477 zmq.select([self._query_socket], [], [])
1477 zmq.select([self._query_socket], [], [])
1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1479 if self.debug:
1479 if self.debug:
1480 pprint(msg)
1480 pprint(msg)
1481 content = msg['content']
1481 content = msg['content']
1482 if content['status'] != 'ok':
1482 if content['status'] != 'ok':
1483 raise self._unwrap_exception(content)
1483 raise self._unwrap_exception(content)
1484 mapping = content['resubmitted']
1484 mapping = content['resubmitted']
1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1486
1486
1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1488
1488
1489 if block:
1489 if block:
1490 ar.wait()
1490 ar.wait()
1491
1491
1492 return ar
1492 return ar
1493
1493
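A hedged sketch of `resubmit` (assumes the controller was started with a task database such as SQLiteDB, since resubmission replays TaskRecords held by the Hub; a controller running the default NoDB backend cannot resubmit):

    from IPython.parallel import Client

    rc = Client()
    view = rc.load_balanced_view()
    ar = view.apply_async(pow, 2, 10)
    ar.get()                                      # must be finished: in-flight tasks cannot be resubmitted
    new_ar = rc.resubmit(ar.msg_ids, block=True)
    print(new_ar.get())                           # results of the re-run task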
1494 @spin_first
1494 @spin_first
1495 def result_status(self, msg_ids, status_only=True):
1495 def result_status(self, msg_ids, status_only=True):
1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1497
1497
1498 If status_only is False, then the actual results will be retrieved, else
1498 If status_only is False, then the actual results will be retrieved, else
1499 only the status of the results will be checked.
1499 only the status of the results will be checked.
1500
1500
1501 Parameters
1501 Parameters
1502 ----------
1502 ----------
1503
1503
1504 msg_ids : list of msg_ids or ints
1504 msg_ids : list of msg_ids or ints
1505 if int:
1505 if int:
1506 Passed as an index to self.history for convenience.
1506 Passed as an index to self.history for convenience.
1507 status_only : bool (default: True)
1507 status_only : bool (default: True)
1508 if False:
1508 if False:
1509 Retrieve the actual results of completed tasks.
1509 Retrieve the actual results of completed tasks.
1510
1510
1511 Returns
1511 Returns
1512 -------
1512 -------
1513
1513
1514 results : dict
1514 results : dict
1515 There will always be the keys 'pending' and 'completed', which will
1515 There will always be the keys 'pending' and 'completed', which will
1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1517 is False, then completed results will be keyed by their `msg_id`.
1517 is False, then completed results will be keyed by their `msg_id`.
1518 """
1518 """
1519 if not isinstance(msg_ids, (list,tuple)):
1519 if not isinstance(msg_ids, (list,tuple)):
1520 msg_ids = [msg_ids]
1520 msg_ids = [msg_ids]
1521
1521
1522 theids = []
1522 theids = []
1523 for msg_id in msg_ids:
1523 for msg_id in msg_ids:
1524 if isinstance(msg_id, int):
1524 if isinstance(msg_id, int):
1525 msg_id = self.history[msg_id]
1525 msg_id = self.history[msg_id]
1526 if not isinstance(msg_id, string_types):
1526 if not isinstance(msg_id, string_types):
1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1528 theids.append(msg_id)
1528 theids.append(msg_id)
1529
1529
1530 completed = []
1530 completed = []
1531 local_results = {}
1531 local_results = {}
1532
1532
1533 # comment this block out to temporarily disable local shortcut:
1533 # comment this block out to temporarily disable local shortcut:
1534 for msg_id in theids:
1534 for msg_id in theids:
1535 if msg_id in self.results:
1535 if msg_id in self.results:
1536 completed.append(msg_id)
1536 completed.append(msg_id)
1537 local_results[msg_id] = self.results[msg_id]
1537 local_results[msg_id] = self.results[msg_id]
1538 theids.remove(msg_id)
1538 theids.remove(msg_id)
1539
1539
1540 if theids: # some not locally cached
1540 if theids: # some not locally cached
1541 content = dict(msg_ids=theids, status_only=status_only)
1541 content = dict(msg_ids=theids, status_only=status_only)
1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1543 zmq.select([self._query_socket], [], [])
1543 zmq.select([self._query_socket], [], [])
1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1545 if self.debug:
1545 if self.debug:
1546 pprint(msg)
1546 pprint(msg)
1547 content = msg['content']
1547 content = msg['content']
1548 if content['status'] != 'ok':
1548 if content['status'] != 'ok':
1549 raise self._unwrap_exception(content)
1549 raise self._unwrap_exception(content)
1550 buffers = msg['buffers']
1550 buffers = msg['buffers']
1551 else:
1551 else:
1552 content = dict(completed=[],pending=[])
1552 content = dict(completed=[],pending=[])
1553
1553
1554 content['completed'].extend(completed)
1554 content['completed'].extend(completed)
1555
1555
1556 if status_only:
1556 if status_only:
1557 return content
1557 return content
1558
1558
1559 failures = []
1559 failures = []
1560 # load cached results into result:
1560 # load cached results into result:
1561 content.update(local_results)
1561 content.update(local_results)
1562
1562
1563 # update cache with results:
1563 # update cache with results:
1564 for msg_id in sorted(theids):
1564 for msg_id in sorted(theids):
1565 if msg_id in content['completed']:
1565 if msg_id in content['completed']:
1566 rec = content[msg_id]
1566 rec = content[msg_id]
1567 parent = rec['header']
1567 parent = extract_dates(rec['header'])
1568 header = rec['result_header']
1568 header = extract_dates(rec['result_header'])
1569 rcontent = rec['result_content']
1569 rcontent = rec['result_content']
1570 iodict = rec['io']
1570 iodict = rec['io']
1571 if isinstance(rcontent, str):
1571 if isinstance(rcontent, str):
1572 rcontent = self.session.unpack(rcontent)
1572 rcontent = self.session.unpack(rcontent)
1573
1573
1574 md = self.metadata[msg_id]
1574 md = self.metadata[msg_id]
1575 md_msg = dict(
1575 md_msg = dict(
1576 content=rcontent,
1576 content=rcontent,
1577 parent_header=parent,
1577 parent_header=parent,
1578 header=header,
1578 header=header,
1579 metadata=rec['result_metadata'],
1579 metadata=rec['result_metadata'],
1580 )
1580 )
1581 md.update(self._extract_metadata(md_msg))
1581 md.update(self._extract_metadata(md_msg))
1582 if rec.get('received'):
1582 if rec.get('received'):
1583 md['received'] = rec['received']
1583 md['received'] = parse_date(rec['received'])
1584 md.update(iodict)
1584 md.update(iodict)
1585
1585
1586 if rcontent['status'] == 'ok':
1586 if rcontent['status'] == 'ok':
1587 if header['msg_type'] == 'apply_reply':
1587 if header['msg_type'] == 'apply_reply':
1588 res,buffers = serialize.unserialize_object(buffers)
1588 res,buffers = serialize.unserialize_object(buffers)
1589 elif header['msg_type'] == 'execute_reply':
1589 elif header['msg_type'] == 'execute_reply':
1590 res = ExecuteReply(msg_id, rcontent, md)
1590 res = ExecuteReply(msg_id, rcontent, md)
1591 else:
1591 else:
1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1593 else:
1593 else:
1594 res = self._unwrap_exception(rcontent)
1594 res = self._unwrap_exception(rcontent)
1595 failures.append(res)
1595 failures.append(res)
1596
1596
1597 self.results[msg_id] = res
1597 self.results[msg_id] = res
1598 content[msg_id] = res
1598 content[msg_id] = res
1599
1599
1600 if len(theids) == 1 and failures:
1600 if len(theids) == 1 and failures:
1601 raise failures[0]
1601 raise failures[0]
1602
1602
1603 error.collect_exceptions(failures, "result_status")
1603 error.collect_exceptions(failures, "result_status")
1604 return content
1604 return content
1605
1605
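A short sketch of `result_status` usage (assumes a running cluster; the msg_ids come from a real submission):

    from IPython.parallel import Client

    rc = Client()
    ar = rc.load_balanced_view().apply_async(sum, range(10))
    status = rc.result_status(ar.msg_ids)            # status_only=True by default
    print(status['pending'], status['completed'])
    full = rc.result_status(ar.msg_ids, status_only=False)
    # once a task completes, its result is additionally keyed by msg_id:
    # full[ar.msg_ids[0]]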
1606 @spin_first
1606 @spin_first
1607 def queue_status(self, targets='all', verbose=False):
1607 def queue_status(self, targets='all', verbose=False):
1608 """Fetch the status of engine queues.
1608 """Fetch the status of engine queues.
1609
1609
1610 Parameters
1610 Parameters
1611 ----------
1611 ----------
1612
1612
1613 targets : int/str/list of ints/strs
1613 targets : int/str/list of ints/strs
1614 the engines whose states are to be queried.
1614 the engines whose states are to be queried.
1615 default : all
1615 default : all
1616 verbose : bool
1616 verbose : bool
1617 Whether to return lengths only, or lists of ids for each element
1617 Whether to return lengths only, or lists of ids for each element
1618 """
1618 """
1619 if targets == 'all':
1619 if targets == 'all':
1620 # allow 'all' to be evaluated on the engine
1620 # allow 'all' to be evaluated on the engine
1621 engine_ids = None
1621 engine_ids = None
1622 else:
1622 else:
1623 engine_ids = self._build_targets(targets)[1]
1623 engine_ids = self._build_targets(targets)[1]
1624 content = dict(targets=engine_ids, verbose=verbose)
1624 content = dict(targets=engine_ids, verbose=verbose)
1625 self.session.send(self._query_socket, "queue_request", content=content)
1625 self.session.send(self._query_socket, "queue_request", content=content)
1626 idents,msg = self.session.recv(self._query_socket, 0)
1626 idents,msg = self.session.recv(self._query_socket, 0)
1627 if self.debug:
1627 if self.debug:
1628 pprint(msg)
1628 pprint(msg)
1629 content = msg['content']
1629 content = msg['content']
1630 status = content.pop('status')
1630 status = content.pop('status')
1631 if status != 'ok':
1631 if status != 'ok':
1632 raise self._unwrap_exception(content)
1632 raise self._unwrap_exception(content)
1633 content = rekey(content)
1633 content = rekey(content)
1634 if isinstance(targets, int):
1634 if isinstance(targets, int):
1635 return content[targets]
1635 return content[targets]
1636 else:
1636 else:
1637 return content
1637 return content
1638
1638
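A usage sketch of `queue_status`; the printed shapes are illustrative rather than guaranteed:

    from IPython.parallel import Client

    rc = Client()
    print(rc.queue_status())              # e.g. {0: {'queue': 0, 'completed': 3, 'tasks': 0}, 'unassigned': 0}
    print(rc.queue_status(targets=0))     # an int target returns just that engine's dict
    print(rc.queue_status(verbose=True))  # lists of msg_ids instead of counts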
1639 def _build_msgids_from_target(self, targets=None):
1639 def _build_msgids_from_target(self, targets=None):
1640 """Build a list of msg_ids from the list of engine targets"""
1640 """Build a list of msg_ids from the list of engine targets"""
1641 if not targets: # needed as _build_targets otherwise uses all engines
1641 if not targets: # needed as _build_targets otherwise uses all engines
1642 return []
1642 return []
1643 target_ids = self._build_targets(targets)[0]
1643 target_ids = self._build_targets(targets)[0]
1644 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1644 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1645
1645
1646 def _build_msgids_from_jobs(self, jobs=None):
1646 def _build_msgids_from_jobs(self, jobs=None):
1647 """Build a list of msg_ids from "jobs" """
1647 """Build a list of msg_ids from "jobs" """
1648 if not jobs:
1648 if not jobs:
1649 return []
1649 return []
1650 msg_ids = []
1650 msg_ids = []
1651 if isinstance(jobs, string_types + (AsyncResult,)):
1651 if isinstance(jobs, string_types + (AsyncResult,)):
1652 jobs = [jobs]
1652 jobs = [jobs]
1653 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1653 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1654 if bad_ids:
1654 if bad_ids:
1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1656 for j in jobs:
1656 for j in jobs:
1657 if isinstance(j, AsyncResult):
1657 if isinstance(j, AsyncResult):
1658 msg_ids.extend(j.msg_ids)
1658 msg_ids.extend(j.msg_ids)
1659 else:
1659 else:
1660 msg_ids.append(j)
1660 msg_ids.append(j)
1661 return msg_ids
1661 return msg_ids
1662
1662
1663 def purge_local_results(self, jobs=[], targets=[]):
1663 def purge_local_results(self, jobs=[], targets=[]):
1664 """Clears the client caches of results, freeing the memory they use.
1664 """Clears the client caches of results, freeing the memory they use.
1665
1665
1666 Individual results can be purged by msg_id, or the entire
1666 Individual results can be purged by msg_id, or the entire
1667 history of specific targets can be purged.
1667 history of specific targets can be purged.
1668
1668
1669 Use `purge_local_results('all')` to scrub everything from the Client's db.
1669 Use `purge_local_results('all')` to scrub everything from the Client's db.
1670
1670
1671 The client must have no outstanding tasks before purging the caches.
1671 The client must have no outstanding tasks before purging the caches.
1672 Raises `AssertionError` if there are still outstanding tasks.
1672 Raises `AssertionError` if there are still outstanding tasks.
1673
1673
1674 After this call, all `AsyncResult` objects are invalid and should be discarded.
1674 After this call, all `AsyncResult` objects are invalid and should be discarded.
1675
1675
1676 If you must "reget" the results, you can still do so by using
1676 If you must "reget" the results, you can still do so by using
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 redownload the results from the hub if they are still available
1678 redownload the results from the hub if they are still available
1679 (i.e. `client.purge_hub_results(...)` has not been called).
1679 (i.e. `client.purge_hub_results(...)` has not been called).
1680
1680
1681 Parameters
1681 Parameters
1682 ----------
1682 ----------
1683
1683
1684 jobs : str or list of str or AsyncResult objects
1684 jobs : str or list of str or AsyncResult objects
1685 the msg_ids whose results should be purged.
1685 the msg_ids whose results should be purged.
1686 targets : int/str/list of ints/strs
1686 targets : int/str/list of ints/strs
1687 The targets, by int_id, whose entire results are to be purged.
1687 The targets, by int_id, whose entire results are to be purged.
1688
1688
1689 default : None
1689 default : None
1690 """
1690 """
1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1692
1692
1693 if not targets and not jobs:
1693 if not targets and not jobs:
1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1695
1695
1696 if jobs == 'all':
1696 if jobs == 'all':
1697 self.results.clear()
1697 self.results.clear()
1698 self.metadata.clear()
1698 self.metadata.clear()
1699 return
1699 return
1700 else:
1700 else:
1701 msg_ids = []
1701 msg_ids = []
1702 msg_ids.extend(self._build_msgids_from_target(targets))
1702 msg_ids.extend(self._build_msgids_from_target(targets))
1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1704 for mid in msg_ids:
1704 for mid in msg_ids:
1705 self.results.pop(mid)
1705 self.results.pop(mid)
1706 self.metadata.pop(mid)
1706 self.metadata.pop(mid)
1707
1707
1708
1708
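A minimal sketch of `purge_local_results` (assumes a running cluster and that all submitted tasks have finished):

    from IPython.parallel import Client

    rc = Client()
    ar = rc[:].apply_async(max, range(5))
    ar.get()
    rc.wait()                             # purging asserts there are no outstanding tasks
    rc.purge_local_results(ar)            # drop this AsyncResult's cached results
    rc.purge_local_results(targets=0)     # or everything cached from engine 0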
1709 @spin_first
1709 @spin_first
1710 def purge_hub_results(self, jobs=[], targets=[]):
1710 def purge_hub_results(self, jobs=[], targets=[]):
1711 """Tell the Hub to forget results.
1711 """Tell the Hub to forget results.
1712
1712
1713 Individual results can be purged by msg_id, or the entire
1713 Individual results can be purged by msg_id, or the entire
1714 history of specific targets can be purged.
1714 history of specific targets can be purged.
1715
1715
1716 Use `purge_results('all')` to scrub everything from the Hub's db.
1716 Use `purge_results('all')` to scrub everything from the Hub's db.
1717
1717
1718 Parameters
1718 Parameters
1719 ----------
1719 ----------
1720
1720
1721 jobs : str or list of str or AsyncResult objects
1721 jobs : str or list of str or AsyncResult objects
1722 the msg_ids whose results should be forgotten.
1722 the msg_ids whose results should be forgotten.
1723 targets : int/str/list of ints/strs
1723 targets : int/str/list of ints/strs
1724 The targets, by int_id, whose entire history is to be purged.
1724 The targets, by int_id, whose entire history is to be purged.
1725
1725
1726 default : None
1726 default : None
1727 """
1727 """
1728 if not targets and not jobs:
1728 if not targets and not jobs:
1729 raise ValueError("Must specify at least one of `targets` and `jobs`")
1729 raise ValueError("Must specify at least one of `targets` and `jobs`")
1730 if targets:
1730 if targets:
1731 targets = self._build_targets(targets)[1]
1731 targets = self._build_targets(targets)[1]
1732
1732
1733 # construct msg_ids from jobs
1733 # construct msg_ids from jobs
1734 if jobs == 'all':
1734 if jobs == 'all':
1735 msg_ids = jobs
1735 msg_ids = jobs
1736 else:
1736 else:
1737 msg_ids = self._build_msgids_from_jobs(jobs)
1737 msg_ids = self._build_msgids_from_jobs(jobs)
1738
1738
1739 content = dict(engine_ids=targets, msg_ids=msg_ids)
1739 content = dict(engine_ids=targets, msg_ids=msg_ids)
1740 self.session.send(self._query_socket, "purge_request", content=content)
1740 self.session.send(self._query_socket, "purge_request", content=content)
1741 idents, msg = self.session.recv(self._query_socket, 0)
1741 idents, msg = self.session.recv(self._query_socket, 0)
1742 if self.debug:
1742 if self.debug:
1743 pprint(msg)
1743 pprint(msg)
1744 content = msg['content']
1744 content = msg['content']
1745 if content['status'] != 'ok':
1745 if content['status'] != 'ok':
1746 raise self._unwrap_exception(content)
1746 raise self._unwrap_exception(content)
1747
1747
1748 def purge_results(self, jobs=[], targets=[]):
1748 def purge_results(self, jobs=[], targets=[]):
1749 """Clears the cached results from both the hub and the local client
1749 """Clears the cached results from both the hub and the local client
1750
1750
1751 Individual results can be purged by msg_id, or the entire
1751 Individual results can be purged by msg_id, or the entire
1752 history of specific targets can be purged.
1752 history of specific targets can be purged.
1753
1753
1754 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1754 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1755 the Client's db.
1755 the Client's db.
1756
1756
1757 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1757 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1758 the same arguments.
1758 the same arguments.
1759
1759
1760 Parameters
1760 Parameters
1761 ----------
1761 ----------
1762
1762
1763 jobs : str or list of str or AsyncResult objects
1763 jobs : str or list of str or AsyncResult objects
1764 the msg_ids whose results should be forgotten.
1764 the msg_ids whose results should be forgotten.
1765 targets : int/str/list of ints/strs
1765 targets : int/str/list of ints/strs
1766 The targets, by int_id, whose entire history is to be purged.
1766 The targets, by int_id, whose entire history is to be purged.
1767
1767
1768 default : None
1768 default : None
1769 """
1769 """
1770 self.purge_local_results(jobs=jobs, targets=targets)
1770 self.purge_local_results(jobs=jobs, targets=targets)
1771 self.purge_hub_results(jobs=jobs, targets=targets)
1771 self.purge_hub_results(jobs=jobs, targets=targets)
1772
1772
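For completeness, the combined form; `jobs='all'` scrubs both caches in one call (sketch, same assumptions as above):

    from IPython.parallel import Client

    rc = Client()
    rc.wait()                      # no outstanding tasks allowed
    rc.purge_results(jobs='all')   # scrub both the local cache and the Hub's db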
1773 def purge_everything(self):
1773 def purge_everything(self):
1774 """Clears all content from previous Tasks from both the hub and the local client
1774 """Clears all content from previous Tasks from both the hub and the local client
1775
1775
1776 In addition to calling `purge_results("all")` it also deletes the history and
1776 In addition to calling `purge_results("all")` it also deletes the history and
1777 other bookkeeping lists.
1777 other bookkeeping lists.
1778 """
1778 """
1779 self.purge_results("all")
1779 self.purge_results("all")
1780 self.history = []
1780 self.history = []
1781 self.session.digest_history.clear()
1781 self.session.digest_history.clear()
1782
1782
1783 @spin_first
1783 @spin_first
1784 def hub_history(self):
1784 def hub_history(self):
1785 """Get the Hub's history
1785 """Get the Hub's history
1786
1786
1787 Just like the Client, the Hub has a history, which is a list of msg_ids.
1787 Just like the Client, the Hub has a history, which is a list of msg_ids.
1788 This will contain the history of all clients, and, depending on configuration,
1788 This will contain the history of all clients, and, depending on configuration,
1789 may contain history across multiple cluster sessions.
1789 may contain history across multiple cluster sessions.
1790
1790
1791 Any msg_id returned here is a valid argument to `get_result`.
1791 Any msg_id returned here is a valid argument to `get_result`.
1792
1792
1793 Returns
1793 Returns
1794 -------
1794 -------
1795
1795
1796 msg_ids : list of strs
1796 msg_ids : list of strs
1797 list of all msg_ids, ordered by task submission time.
1797 list of all msg_ids, ordered by task submission time.
1798 """
1798 """
1799
1799
1800 self.session.send(self._query_socket, "history_request", content={})
1800 self.session.send(self._query_socket, "history_request", content={})
1801 idents, msg = self.session.recv(self._query_socket, 0)
1801 idents, msg = self.session.recv(self._query_socket, 0)
1802
1802
1803 if self.debug:
1803 if self.debug:
1804 pprint(msg)
1804 pprint(msg)
1805 content = msg['content']
1805 content = msg['content']
1806 if content['status'] != 'ok':
1806 if content['status'] != 'ok':
1807 raise self._unwrap_exception(content)
1807 raise self._unwrap_exception(content)
1808 else:
1808 else:
1809 return content['history']
1809 return content['history']
1810
1810
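A brief sketch pairing `hub_history` with `get_result` (assumes a controller whose Hub keeps task history):

    from IPython.parallel import Client

    rc = Client()
    hist = rc.hub_history()                        # msg_ids ordered by submission time
    if hist:
        ar = rc.get_result(hist[-1], block=True)   # any hub msg_id is valid here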
1811 @spin_first
1811 @spin_first
1812 def db_query(self, query, keys=None):
1812 def db_query(self, query, keys=None):
1813 """Query the Hub's TaskRecord database
1813 """Query the Hub's TaskRecord database
1814
1814
1815 This will return a list of task record dicts that match `query`
1815 This will return a list of task record dicts that match `query`
1816
1816
1817 Parameters
1817 Parameters
1818 ----------
1818 ----------
1819
1819
1820 query : mongodb query dict
1820 query : mongodb query dict
1821 The search dict. See mongodb query docs for details.
1821 The search dict. See mongodb query docs for details.
1822 keys : list of strs [optional]
1822 keys : list of strs [optional]
1823 The subset of keys to be returned. The default is to fetch everything but buffers.
1823 The subset of keys to be returned. The default is to fetch everything but buffers.
1824 'msg_id' will *always* be included.
1824 'msg_id' will *always* be included.
1825 """
1825 """
1826 if isinstance(keys, string_types):
1826 if isinstance(keys, string_types):
1827 keys = [keys]
1827 keys = [keys]
1828 content = dict(query=query, keys=keys)
1828 content = dict(query=query, keys=keys)
1829 self.session.send(self._query_socket, "db_request", content=content)
1829 self.session.send(self._query_socket, "db_request", content=content)
1830 idents, msg = self.session.recv(self._query_socket, 0)
1830 idents, msg = self.session.recv(self._query_socket, 0)
1831 if self.debug:
1831 if self.debug:
1832 pprint(msg)
1832 pprint(msg)
1833 content = msg['content']
1833 content = msg['content']
1834 if content['status'] != 'ok':
1834 if content['status'] != 'ok':
1835 raise self._unwrap_exception(content)
1835 raise self._unwrap_exception(content)
1836
1836
1837 records = content['records']
1837 records = content['records']
1838
1838
1839 buffer_lens = content['buffer_lens']
1839 buffer_lens = content['buffer_lens']
1840 result_buffer_lens = content['result_buffer_lens']
1840 result_buffer_lens = content['result_buffer_lens']
1841 buffers = msg['buffers']
1841 buffers = msg['buffers']
1842 has_bufs = buffer_lens is not None
1842 has_bufs = buffer_lens is not None
1843 has_rbufs = result_buffer_lens is not None
1843 has_rbufs = result_buffer_lens is not None
1844 for i,rec in enumerate(records):
1844 for i,rec in enumerate(records):
1845 # unpack datetime objects
1846 for hkey in ('header', 'result_header'):
1847 if hkey in rec:
1848 rec[hkey] = extract_dates(rec[hkey])
1849 for dtkey in ('submitted', 'started', 'completed', 'received'):
1850 if dtkey in rec:
1851 rec[dtkey] = parse_date(rec[dtkey])
1845 # relink buffers
1852 # relink buffers
1846 if has_bufs:
1853 if has_bufs:
1847 blen = buffer_lens[i]
1854 blen = buffer_lens[i]
1848 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1855 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1849 if has_rbufs:
1856 if has_rbufs:
1850 blen = result_buffer_lens[i]
1857 blen = result_buffer_lens[i]
1851 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1858 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1852
1859
1853 return records
1860 return records
1854
1861
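A hedged sketch of `db_query`, which also shows the effect of the datetime unpacking added above: date fields now come back as datetime objects rather than ISO8601 strings. The `$ne` filter is MongoDB-style, and the key names follow the TaskRecord schema:

    from datetime import datetime
    from IPython.parallel import Client

    rc = Client()
    records = rc.db_query({'completed': {'$ne': None}},
                          keys=['msg_id', 'completed', 'engine_uuid'])
    for rec in records:
        assert isinstance(rec['completed'], datetime)   # unpacked, not a string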
1855 __all__ = [ 'Client' ]
1862 __all__ = [ 'Client' ]
@@ -1,1421 +1,1422
1 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
2 This is the master object that handles connections from engines and clients,
2 This is the master object that handles connections from engines and clients,
3 and monitors traffic through the various queues.
3 and monitors traffic through the various queues.
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 """
8 """
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 from __future__ import print_function
19 from __future__ import print_function
20
20
21 import json
21 import json
22 import os
22 import os
23 import sys
23 import sys
24 import time
24 import time
25 from datetime import datetime
25 from datetime import datetime
26
26
27 import zmq
27 import zmq
28 from zmq.eventloop import ioloop
28 from zmq.eventloop import ioloop
29 from zmq.eventloop.zmqstream import ZMQStream
29 from zmq.eventloop.zmqstream import ZMQStream
30
30
31 # internal:
31 # internal:
32 from IPython.utils.importstring import import_item
32 from IPython.utils.importstring import import_item
33 from IPython.utils.jsonutil import extract_dates
33 from IPython.utils.localinterfaces import localhost
34 from IPython.utils.localinterfaces import localhost
34 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
35 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
35 from IPython.utils.traitlets import (
36 from IPython.utils.traitlets import (
36 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 )
38 )
38
39
39 from IPython.parallel import error, util
40 from IPython.parallel import error, util
40 from IPython.parallel.factory import RegistrationFactory
41 from IPython.parallel.factory import RegistrationFactory
41
42
42 from IPython.kernel.zmq.session import SessionFactory
43 from IPython.kernel.zmq.session import SessionFactory
43
44
44 from .heartmonitor import HeartMonitor
45 from .heartmonitor import HeartMonitor
45
46
46 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
47 # Code
48 # Code
48 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
49
50
50 def _passer(*args, **kwargs):
51 def _passer(*args, **kwargs):
51 return
52 return
52
53
53 def _printer(*args, **kwargs):
54 def _printer(*args, **kwargs):
54 print (args)
55 print (args)
55 print (kwargs)
56 print (kwargs)
56
57
57 def empty_record():
58 def empty_record():
58 """Return an empty dict with all record keys."""
59 """Return an empty dict with all record keys."""
59 return {
60 return {
60 'msg_id' : None,
61 'msg_id' : None,
61 'header' : None,
62 'header' : None,
62 'metadata' : None,
63 'metadata' : None,
63 'content': None,
64 'content': None,
64 'buffers': None,
65 'buffers': None,
65 'submitted': None,
66 'submitted': None,
66 'client_uuid' : None,
67 'client_uuid' : None,
67 'engine_uuid' : None,
68 'engine_uuid' : None,
68 'started': None,
69 'started': None,
69 'completed': None,
70 'completed': None,
70 'resubmitted': None,
71 'resubmitted': None,
71 'received': None,
72 'received': None,
72 'result_header' : None,
73 'result_header' : None,
73 'result_metadata' : None,
74 'result_metadata' : None,
74 'result_content' : None,
75 'result_content' : None,
75 'result_buffers' : None,
76 'result_buffers' : None,
76 'queue' : None,
77 'queue' : None,
77 'pyin' : None,
78 'pyin' : None,
78 'pyout': None,
79 'pyout': None,
79 'pyerr': None,
80 'pyerr': None,
80 'stdout': '',
81 'stdout': '',
81 'stderr': '',
82 'stderr': '',
82 }
83 }
83
84
84 def init_record(msg):
85 def init_record(msg):
85 """Initialize a TaskRecord based on a request."""
86 """Initialize a TaskRecord based on a request."""
86 header = msg['header']
87 header = msg['header']
87 return {
88 return {
88 'msg_id' : header['msg_id'],
89 'msg_id' : header['msg_id'],
89 'header' : header,
90 'header' : header,
90 'content': msg['content'],
91 'content': msg['content'],
91 'metadata': msg['metadata'],
92 'metadata': msg['metadata'],
92 'buffers': msg['buffers'],
93 'buffers': msg['buffers'],
93 'submitted': header['date'],
94 'submitted': header['date'],
94 'client_uuid' : None,
95 'client_uuid' : None,
95 'engine_uuid' : None,
96 'engine_uuid' : None,
96 'started': None,
97 'started': None,
97 'completed': None,
98 'completed': None,
98 'resubmitted': None,
99 'resubmitted': None,
99 'received': None,
100 'received': None,
100 'result_header' : None,
101 'result_header' : None,
101 'result_metadata': None,
102 'result_metadata': None,
102 'result_content' : None,
103 'result_content' : None,
103 'result_buffers' : None,
104 'result_buffers' : None,
104 'queue' : None,
105 'queue' : None,
105 'pyin' : None,
106 'pyin' : None,
106 'pyout': None,
107 'pyout': None,
107 'pyerr': None,
108 'pyerr': None,
108 'stdout': '',
109 'stdout': '',
109 'stderr': '',
110 'stderr': '',
110 }
111 }
111
112
112
113
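A small standalone sketch of the date round-trip this merge is concerned with: headers travel as JSON carrying ISO8601 strings, and `extract_dates` from IPython.utils.jsonutil converts them back into datetime objects (the sample header values here are made up):

    from datetime import datetime
    from IPython.utils.jsonutil import extract_dates

    header = {'msg_id': 'abc123', 'date': '2013-10-18T12:00:00.000000'}
    header = extract_dates(header)          # ISO8601 strings -> datetime objects
    assert isinstance(header['date'], datetime)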
113 class EngineConnector(HasTraits):
114 class EngineConnector(HasTraits):
114 """A simple object for accessing the various zmq connections of an object.
115 """A simple object for accessing the various zmq connections of an object.
115 Attributes are:
116 Attributes are:
116 id (int): engine ID
117 id (int): engine ID
117 uuid (unicode): engine UUID
118 uuid (unicode): engine UUID
118 pending: set of msg_ids
119 pending: set of msg_ids
119 stallback: DelayedCallback for stalled registration
120 stallback: DelayedCallback for stalled registration
120 """
121 """
121
122
122 id = Integer(0)
123 id = Integer(0)
123 uuid = Unicode()
124 uuid = Unicode()
124 pending = Set()
125 pending = Set()
125 stallback = Instance(ioloop.DelayedCallback)
126 stallback = Instance(ioloop.DelayedCallback)
126
127
127
128
128 _db_shortcuts = {
129 _db_shortcuts = {
129 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
130 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
130 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
131 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
131 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
132 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
132 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
133 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
133 }
134 }
134
135
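How the `_db_shortcuts` mapping above gets used: a standalone mimic of the lookup performed in `init_hub`, with two of the entries inlined for self-containment:

    from IPython.utils.importstring import import_item

    shortcuts = {
        'sqlitedb': 'IPython.parallel.controller.sqlitedb.SQLiteDB',
        'nodb':     'IPython.parallel.controller.dictdb.NoDB',
    }
    name = 'NoDB'
    db_class = shortcuts.get(name.lower(), name)   # unknown names pass through as dotted paths
    DB = import_item(str(db_class))                # -> the NoDB class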
135 class HubFactory(RegistrationFactory):
136 class HubFactory(RegistrationFactory):
136 """The Configurable for setting up a Hub."""
137 """The Configurable for setting up a Hub."""
137
138
138 # port-pairs for monitoredqueues:
139 # port-pairs for monitoredqueues:
139 hb = Tuple(Integer,Integer,config=True,
140 hb = Tuple(Integer,Integer,config=True,
140 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 def _hb_default(self):
142 def _hb_default(self):
142 return tuple(util.select_random_ports(2))
143 return tuple(util.select_random_ports(2))
143
144
144 mux = Tuple(Integer,Integer,config=True,
145 mux = Tuple(Integer,Integer,config=True,
145 help="""Client/Engine Port pair for MUX queue""")
146 help="""Client/Engine Port pair for MUX queue""")
146
147
147 def _mux_default(self):
148 def _mux_default(self):
148 return tuple(util.select_random_ports(2))
149 return tuple(util.select_random_ports(2))
149
150
150 task = Tuple(Integer,Integer,config=True,
151 task = Tuple(Integer,Integer,config=True,
151 help="""Client/Engine Port pair for Task queue""")
152 help="""Client/Engine Port pair for Task queue""")
152 def _task_default(self):
153 def _task_default(self):
153 return tuple(util.select_random_ports(2))
154 return tuple(util.select_random_ports(2))
154
155
155 control = Tuple(Integer,Integer,config=True,
156 control = Tuple(Integer,Integer,config=True,
156 help="""Client/Engine Port pair for Control queue""")
157 help="""Client/Engine Port pair for Control queue""")
157
158
158 def _control_default(self):
159 def _control_default(self):
159 return tuple(util.select_random_ports(2))
160 return tuple(util.select_random_ports(2))
160
161
161 iopub = Tuple(Integer,Integer,config=True,
162 iopub = Tuple(Integer,Integer,config=True,
162 help="""Client/Engine Port pair for IOPub relay""")
163 help="""Client/Engine Port pair for IOPub relay""")
163
164
164 def _iopub_default(self):
165 def _iopub_default(self):
165 return tuple(util.select_random_ports(2))
166 return tuple(util.select_random_ports(2))
166
167
167 # single ports:
168 # single ports:
168 mon_port = Integer(config=True,
169 mon_port = Integer(config=True,
169 help="""Monitor (SUB) port for queue traffic""")
170 help="""Monitor (SUB) port for queue traffic""")
170
171
171 def _mon_port_default(self):
172 def _mon_port_default(self):
172 return util.select_random_ports(1)[0]
173 return util.select_random_ports(1)[0]
173
174
174 notifier_port = Integer(config=True,
175 notifier_port = Integer(config=True,
175 help="""PUB port for sending engine status notifications""")
176 help="""PUB port for sending engine status notifications""")
176
177
177 def _notifier_port_default(self):
178 def _notifier_port_default(self):
178 return util.select_random_ports(1)[0]
179 return util.select_random_ports(1)[0]
179
180
180 engine_ip = Unicode(config=True,
181 engine_ip = Unicode(config=True,
181 help="IP on which to listen for engine connections. [default: loopback]")
182 help="IP on which to listen for engine connections. [default: loopback]")
182 def _engine_ip_default(self):
183 def _engine_ip_default(self):
183 return localhost()
184 return localhost()
184 engine_transport = Unicode('tcp', config=True,
185 engine_transport = Unicode('tcp', config=True,
185 help="0MQ transport for engine connections. [default: tcp]")
186 help="0MQ transport for engine connections. [default: tcp]")
186
187
187 client_ip = Unicode(config=True,
188 client_ip = Unicode(config=True,
188 help="IP on which to listen for client connections. [default: loopback]")
189 help="IP on which to listen for client connections. [default: loopback]")
189 client_transport = Unicode('tcp', config=True,
190 client_transport = Unicode('tcp', config=True,
190 help="0MQ transport for client connections. [default: tcp]")
191 help="0MQ transport for client connections. [default: tcp]")
191
192
192 monitor_ip = Unicode(config=True,
193 monitor_ip = Unicode(config=True,
193 help="IP on which to listen for monitor messages. [default: loopback]")
194 help="IP on which to listen for monitor messages. [default: loopback]")
194 monitor_transport = Unicode('tcp', config=True,
195 monitor_transport = Unicode('tcp', config=True,
195 help="0MQ transport for monitor messages. [default: tcp]")
196 help="0MQ transport for monitor messages. [default: tcp]")
196
197
197 _client_ip_default = _monitor_ip_default = _engine_ip_default
198 _client_ip_default = _monitor_ip_default = _engine_ip_default
198
199
199
200
200 monitor_url = Unicode('')
201 monitor_url = Unicode('')
201
202
202 db_class = DottedObjectName('NoDB',
203 db_class = DottedObjectName('NoDB',
203 config=True, help="""The class to use for the DB backend
204 config=True, help="""The class to use for the DB backend
204
205
205 Options include:
206 Options include:
206
207
207 SQLiteDB: SQLite
208 SQLiteDB: SQLite
208 MongoDB : use MongoDB
209 MongoDB : use MongoDB
209 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
210 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
210 NoDB : disable database altogether (default)
211 NoDB : disable database altogether (default)
211
212
212 """)
213 """)
213
214
214 # not configurable
215 # not configurable
215 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
216 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
216 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
217 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
217
218
218 def _ip_changed(self, name, old, new):
219 def _ip_changed(self, name, old, new):
219 self.engine_ip = new
220 self.engine_ip = new
220 self.client_ip = new
221 self.client_ip = new
221 self.monitor_ip = new
222 self.monitor_ip = new
222 self._update_monitor_url()
223 self._update_monitor_url()
223
224
224 def _update_monitor_url(self):
225 def _update_monitor_url(self):
225 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
226 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
226
227
227 def _transport_changed(self, name, old, new):
228 def _transport_changed(self, name, old, new):
228 self.engine_transport = new
229 self.engine_transport = new
229 self.client_transport = new
230 self.client_transport = new
230 self.monitor_transport = new
231 self.monitor_transport = new
231 self._update_monitor_url()
232 self._update_monitor_url()
232
233
233 def __init__(self, **kwargs):
234 def __init__(self, **kwargs):
234 super(HubFactory, self).__init__(**kwargs)
235 super(HubFactory, self).__init__(**kwargs)
235 self._update_monitor_url()
236 self._update_monitor_url()
236
237
237
238
238 def construct(self):
239 def construct(self):
239 self.init_hub()
240 self.init_hub()
240
241
241 def start(self):
242 def start(self):
242 self.heartmonitor.start()
243 self.heartmonitor.start()
243 self.log.info("Heartmonitor started")
244 self.log.info("Heartmonitor started")
244
245
245 def client_url(self, channel):
246 def client_url(self, channel):
246 """return full zmq url for a named client channel"""
247 """return full zmq url for a named client channel"""
247 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
248 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
248
249
249 def engine_url(self, channel):
250 def engine_url(self, channel):
250 """return full zmq url for a named engine channel"""
251 """return full zmq url for a named engine channel"""
251 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
252 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
252
253
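The URL helpers above simply join transport, IP, and the port looked up in the connection dict; a standalone sketch with made-up values:

    client_transport, client_ip = 'tcp', '127.0.0.1'
    client_info = {'registration': 55001, 'task': 55002}

    def client_url(channel):
        # mirrors HubFactory.client_url
        return "%s://%s:%i" % (client_transport, client_ip, client_info[channel])

    print(client_url('registration'))   # -> tcp://127.0.0.1:55001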
253 def init_hub(self):
254 def init_hub(self):
254 """construct Hub object"""
255 """construct Hub object"""
255
256
256 ctx = self.context
257 ctx = self.context
257 loop = self.loop
258 loop = self.loop
258 if 'TaskScheduler.scheme_name' in self.config:
259 if 'TaskScheduler.scheme_name' in self.config:
259 scheme = self.config.TaskScheduler.scheme_name
260 scheme = self.config.TaskScheduler.scheme_name
260 else:
261 else:
261 from .scheduler import TaskScheduler
262 from .scheduler import TaskScheduler
262 scheme = TaskScheduler.scheme_name.get_default_value()
263 scheme = TaskScheduler.scheme_name.get_default_value()
263
264
264 # build connection dicts
265 # build connection dicts
265 engine = self.engine_info = {
266 engine = self.engine_info = {
266 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
267 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
267 'registration' : self.regport,
268 'registration' : self.regport,
268 'control' : self.control[1],
269 'control' : self.control[1],
269 'mux' : self.mux[1],
270 'mux' : self.mux[1],
270 'hb_ping' : self.hb[0],
271 'hb_ping' : self.hb[0],
271 'hb_pong' : self.hb[1],
272 'hb_pong' : self.hb[1],
272 'task' : self.task[1],
273 'task' : self.task[1],
273 'iopub' : self.iopub[1],
274 'iopub' : self.iopub[1],
274 }
275 }
275
276
276 client = self.client_info = {
277 client = self.client_info = {
277 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
278 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
278 'registration' : self.regport,
279 'registration' : self.regport,
279 'control' : self.control[0],
280 'control' : self.control[0],
280 'mux' : self.mux[0],
281 'mux' : self.mux[0],
281 'task' : self.task[0],
282 'task' : self.task[0],
282 'task_scheme' : scheme,
283 'task_scheme' : scheme,
283 'iopub' : self.iopub[0],
284 'iopub' : self.iopub[0],
284 'notification' : self.notifier_port,
285 'notification' : self.notifier_port,
285 }
286 }
286
287
287 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 self.log.debug("Hub client addrs: %s", self.client_info)
289 self.log.debug("Hub client addrs: %s", self.client_info)
289
290
290 # Registrar socket
291 # Registrar socket
291 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
292 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
292 util.set_hwm(q, 0)
293 util.set_hwm(q, 0)
293 q.bind(self.client_url('registration'))
294 q.bind(self.client_url('registration'))
294 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
295 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
295 if self.client_ip != self.engine_ip:
296 if self.client_ip != self.engine_ip:
296 q.bind(self.engine_url('registration'))
297 q.bind(self.engine_url('registration'))
297 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
298 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
298
299
299 ### Engine connections ###
300 ### Engine connections ###
300
301
301 # heartbeat
302 # heartbeat
302 hpub = ctx.socket(zmq.PUB)
303 hpub = ctx.socket(zmq.PUB)
303 hpub.bind(self.engine_url('hb_ping'))
304 hpub.bind(self.engine_url('hb_ping'))
304 hrep = ctx.socket(zmq.ROUTER)
305 hrep = ctx.socket(zmq.ROUTER)
305 util.set_hwm(hrep, 0)
306 util.set_hwm(hrep, 0)
306 hrep.bind(self.engine_url('hb_pong'))
307 hrep.bind(self.engine_url('hb_pong'))
307 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
308 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
308 pingstream=ZMQStream(hpub,loop),
309 pingstream=ZMQStream(hpub,loop),
309 pongstream=ZMQStream(hrep,loop)
310 pongstream=ZMQStream(hrep,loop)
310 )
311 )
311
312
312 ### Client connections ###
313 ### Client connections ###
313
314
314 # Notifier socket
315 # Notifier socket
315 n = ZMQStream(ctx.socket(zmq.PUB), loop)
316 n = ZMQStream(ctx.socket(zmq.PUB), loop)
316 n.bind(self.client_url('notification'))
317 n.bind(self.client_url('notification'))
317
318
318 ### build and launch the queues ###
319 ### build and launch the queues ###
319
320
320 # monitor socket
321 # monitor socket
321 sub = ctx.socket(zmq.SUB)
322 sub = ctx.socket(zmq.SUB)
322 sub.setsockopt(zmq.SUBSCRIBE, b"")
323 sub.setsockopt(zmq.SUBSCRIBE, b"")
323 sub.bind(self.monitor_url)
324 sub.bind(self.monitor_url)
324 sub.bind('inproc://monitor')
325 sub.bind('inproc://monitor')
325 sub = ZMQStream(sub, loop)
326 sub = ZMQStream(sub, loop)
326
327
327 # connect the db
328 # connect the db
328 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
329 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
329 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
330 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
330 self.db = import_item(str(db_class))(session=self.session.session,
331 self.db = import_item(str(db_class))(session=self.session.session,
331 parent=self, log=self.log)
332 parent=self, log=self.log)
332 time.sleep(.25)
333 time.sleep(.25)
333
334
334 # resubmit stream
335 # resubmit stream
335 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
336 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
336 url = util.disambiguate_url(self.client_url('task'))
337 url = util.disambiguate_url(self.client_url('task'))
337 r.connect(url)
338 r.connect(url)
338
339
339 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
340 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
340 query=q, notifier=n, resubmit=r, db=self.db,
341 query=q, notifier=n, resubmit=r, db=self.db,
341 engine_info=self.engine_info, client_info=self.client_info,
342 engine_info=self.engine_info, client_info=self.client_info,
342 log=self.log)
343 log=self.log)
343
344
344
345
345 class Hub(SessionFactory):
346 class Hub(SessionFactory):
346 """The IPython Controller Hub with 0MQ connections
347 """The IPython Controller Hub with 0MQ connections
347
348
348 Parameters
349 Parameters
349 ==========
350 ==========
350 loop: zmq IOLoop instance
351 loop: zmq IOLoop instance
351 session: Session object
352 session: Session object
352 <removed> context: zmq context for creating new connections (?)
353 <removed> context: zmq context for creating new connections (?)
353 queue: ZMQStream for monitoring the command queue (SUB)
354 queue: ZMQStream for monitoring the command queue (SUB)
354 query: ZMQStream for engine registration and client query requests (ROUTER)
355 query: ZMQStream for engine registration and client query requests (ROUTER)
355 heartbeat: HeartMonitor object checking the pulse of the engines
356 heartbeat: HeartMonitor object checking the pulse of the engines
356 notifier: ZMQStream for broadcasting engine registration changes (PUB)
357 notifier: ZMQStream for broadcasting engine registration changes (PUB)
357 db: connection to db for out of memory logging of commands
358 db: connection to db for out of memory logging of commands
358 NotImplemented
359 NotImplemented
359 engine_info: dict of zmq connection information for engines to connect
360 engine_info: dict of zmq connection information for engines to connect
360 to the queues.
361 to the queues.
361 client_info: dict of zmq connection information for engines to connect
362 client_info: dict of zmq connection information for engines to connect
362 to the queues.
363 to the queues.
363 """
364 """
364
365
365 engine_state_file = Unicode()
366 engine_state_file = Unicode()
366
367
367 # internal data structures:
368 # internal data structures:
368 ids=Set() # engine IDs
369 ids=Set() # engine IDs
369 keytable=Dict()
370 keytable=Dict()
370 by_ident=Dict()
371 by_ident=Dict()
371 engines=Dict()
372 engines=Dict()
372 clients=Dict()
373 clients=Dict()
373 hearts=Dict()
374 hearts=Dict()
374 pending=Set()
375 pending=Set()
375 queues=Dict() # pending msg_ids keyed by engine_id
376 queues=Dict() # pending msg_ids keyed by engine_id
376 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
377 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
377 completed=Dict() # completed msg_ids keyed by engine_id
378 completed=Dict() # completed msg_ids keyed by engine_id
378 all_completed=Set() # set of all completed msg_ids
379 all_completed=Set() # set of all completed msg_ids
379 dead_engines=Set() # set of uuids of engines that have died
380 dead_engines=Set() # set of uuids of engines that have died
380 unassigned=Set() # set of task msg_ids not yet assigned a destination
381 unassigned=Set() # set of task msg_ids not yet assigned a destination
381 incoming_registrations=Dict()
382 incoming_registrations=Dict()
382 registration_timeout=Integer()
383 registration_timeout=Integer()
383 _idcounter=Integer(0)
384 _idcounter=Integer(0)
384
385
385 # objects from constructor:
386 # objects from constructor:
386 query=Instance(ZMQStream)
387 query=Instance(ZMQStream)
387 monitor=Instance(ZMQStream)
388 monitor=Instance(ZMQStream)
388 notifier=Instance(ZMQStream)
389 notifier=Instance(ZMQStream)
389 resubmit=Instance(ZMQStream)
390 resubmit=Instance(ZMQStream)
390 heartmonitor=Instance(HeartMonitor)
391 heartmonitor=Instance(HeartMonitor)
391 db=Instance(object)
392 db=Instance(object)
392 client_info=Dict()
393 client_info=Dict()
393 engine_info=Dict()
394 engine_info=Dict()
394
395
395
396
396 def __init__(self, **kwargs):
397 def __init__(self, **kwargs):
397 """
398 """
398 # universal:
399 # universal:
399 loop: IOLoop for creating future connections
400 loop: IOLoop for creating future connections
400 session: streamsession for sending serialized data
401 session: streamsession for sending serialized data
401 # engine:
402 # engine:
402 queue: ZMQStream for monitoring queue messages
403 queue: ZMQStream for monitoring queue messages
403 query: ZMQStream for engine+client registration and client requests
404 query: ZMQStream for engine+client registration and client requests
404 heartbeat: HeartMonitor object for tracking engines
405 heartbeat: HeartMonitor object for tracking engines
405 # extra:
406 # extra:
406 db: ZMQStream for db connection (NotImplemented)
407 db: ZMQStream for db connection (NotImplemented)
407 engine_info: zmq address/protocol dict for engine connections
408 engine_info: zmq address/protocol dict for engine connections
408 client_info: zmq address/protocol dict for client connections
409 client_info: zmq address/protocol dict for client connections
409 """
410 """
410
411
411 super(Hub, self).__init__(**kwargs)
412 super(Hub, self).__init__(**kwargs)
412 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
413 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
413
414
414 # register our callbacks
415 # register our callbacks
415 self.query.on_recv(self.dispatch_query)
416 self.query.on_recv(self.dispatch_query)
416 self.monitor.on_recv(self.dispatch_monitor_traffic)
417 self.monitor.on_recv(self.dispatch_monitor_traffic)
417
418
418 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
419 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
419 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
420 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
420
421
421 self.monitor_handlers = {b'in' : self.save_queue_request,
422 self.monitor_handlers = {b'in' : self.save_queue_request,
422 b'out': self.save_queue_result,
423 b'out': self.save_queue_result,
423 b'intask': self.save_task_request,
424 b'intask': self.save_task_request,
424 b'outtask': self.save_task_result,
425 b'outtask': self.save_task_result,
425 b'tracktask': self.save_task_destination,
426 b'tracktask': self.save_task_destination,
426 b'incontrol': _passer,
427 b'incontrol': _passer,
427 b'outcontrol': _passer,
428 b'outcontrol': _passer,
428 b'iopub': self.save_iopub_message,
429 b'iopub': self.save_iopub_message,
429 }
430 }
430
431
431 self.query_handlers = {'queue_request': self.queue_status,
432 self.query_handlers = {'queue_request': self.queue_status,
432 'result_request': self.get_results,
433 'result_request': self.get_results,
433 'history_request': self.get_history,
434 'history_request': self.get_history,
434 'db_request': self.db_query,
435 'db_request': self.db_query,
435 'purge_request': self.purge_results,
436 'purge_request': self.purge_results,
436 'load_request': self.check_load,
437 'load_request': self.check_load,
437 'resubmit_request': self.resubmit_task,
438 'resubmit_request': self.resubmit_task,
438 'shutdown_request': self.shutdown_request,
439 'shutdown_request': self.shutdown_request,
439 'registration_request' : self.register_engine,
440 'registration_request' : self.register_engine,
440 'unregistration_request' : self.unregister_engine,
441 'unregistration_request' : self.unregister_engine,
441 'connection_request': self.connection_request,
442 'connection_request': self.connection_request,
442 }
443 }
443
444
444 # ignore resubmit replies
445 # ignore resubmit replies
445 self.resubmit.on_recv(lambda msg: None, copy=False)
446 self.resubmit.on_recv(lambda msg: None, copy=False)
446
447
447 self.log.info("hub::created hub")
448 self.log.info("hub::created hub")
448
449
449 @property
450 @property
450 def _next_id(self):
451 def _next_id(self):
451 """generate a new ID.
452 """generate a new ID.
452
453
453 No longer reuse old ids, just count from 0."""
454 No longer reuse old ids, just count from 0."""
454 newid = self._idcounter
455 newid = self._idcounter
455 self._idcounter += 1
456 self._idcounter += 1
456 return newid
457 return newid
457 # newid = 0
458 # newid = 0
458 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
459 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
459 # # print newid, self.ids, self.incoming_registrations
460 # # print newid, self.ids, self.incoming_registrations
460 # while newid in self.ids or newid in incoming:
461 # while newid in self.ids or newid in incoming:
461 # newid += 1
462 # newid += 1
462 # return newid
463 # return newid
463
464
    #-----------------------------------------------------------------------------
    # message validation
    #-----------------------------------------------------------------------------

    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids"""
        if targets is None:
            # default to all
            return self.ids

        if isinstance(targets, (int, str, unicode_type)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str, unicode_type)):
                t = self.by_ident.get(cast_bytes(t), t)
            _targets.append(t)
        targets = _targets
        bad_targets = [t for t in targets if t not in self.ids]
        if bad_targets:
            raise IndexError("No Such Engine: %r" % bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets

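    # Illustrative examples (not in the original source), assuming engines
    # 0 and 1 are registered:
    #   _validate_targets(None)               -> self.ids (all registered engines)
    #   _validate_targets(0)                  -> [0]
    #   _validate_targets('an-engine-uuid')   -> [0]   # mapped via self.by_ident
    #   _validate_targets([5])                -> raises IndexError("No Such Engine: [5]")
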
    #-----------------------------------------------------------------------------
    # dispatch methods (1 per stream)
    #-----------------------------------------------------------------------------


    @util.log_errors
    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic."""
        self.log.debug("monitor traffic: %r", msg[0])
        switch = msg[0]
        try:
            idents, msg = self.session.feed_identities(msg[1:])
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Monitor message without topic: %r", msg)
            return
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Unrecognized monitor topic: %r", switch)


    @util.log_errors
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r", msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r", msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        # switch on message type:
        msg_type = msg['header']['msg_type']
        self.log.info("client::client %r requested %r", client_id, msg_type)
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %r" % msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            handler(idents, msg)

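    # Illustrative note (not in the original source): feed_identities splits a
    # raw multipart message on the b'<IDS|MSG>' delimiter frame, returning the
    # routing identities and the remaining signed frames, roughly:
    #   [ident0, ..., b'<IDS|MSG>', signature, header, parent, metadata, content]
    #   -> ([ident0, ...], [signature, header, parent, metadata, content])
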
    def dispatch_db(self, msg):
        """"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # handler methods (1 per event)
    #---------------------------------------------------------------------------

    #----------------------- Heartbeat --------------------------------------

    def handle_new_heart(self, heart):
        """handler to attach to heartbeater.
        Called when a new heart starts to beat.
        Triggers completion of registration."""
        self.log.debug("heartbeat::handle_new_heart(%r)", heart)
        if heart not in self.incoming_registrations:
            self.log.info("heartbeat::ignoring new heart: %r", heart)
        else:
            self.finish_registration(heart)


    def handle_heart_failure(self, heart):
        """handler to attach to heartbeater.
        called when a previously registered heart fails to respond to beat request.
        triggers unregistration"""
        self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
        eid = self.hearts.get(heart, None)
        if eid is None or self.keytable[eid] in self.dead_engines:
            self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
        else:
            # look up the uuid only after confirming eid refers to a live engine,
            # so a heart that never registered cannot raise a KeyError here
            uuid = self.engines[eid].uuid
            self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))

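    # Illustrative note (not in the original source): registration is two-phase.
    # register_engine (further below) records the engine in incoming_registrations
    # and replies with an id; only when the HeartMonitor sees the engine's first
    # beat does handle_new_heart fire and finish_registration make the engine
    # visible in self.ids / self.engines.
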
    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = msg['header']['session']
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key, evalue in iteritems(existing):
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)

        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)

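    # Illustrative note (not in the original source): if an iopub message for
    # this msg_id already created a stub record (e.g. with some stdout), the
    # merge above keeps any non-empty existing values the new request record
    # lacks, and only logs a warning when both sides disagree.
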
    def save_queue_result(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        md = msg['metadata']
        completed = rheader['date']
        started = md.get('started', None)
        result = {
            'result_header' : rheader,
            'result_metadata': md,
            'result_content': msg['content'],
            'received': datetime.now(),
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)


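    # Illustrative note (not in the original source): the three timestamps give
    # a task timeline -- 'started' comes from the engine's metadata, 'completed'
    # from the reply header's date, and 'received' is stamped by the Hub, so
    # (completed - started) approximates compute time and (received - completed)
    # the return trip.
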
    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = msg['header']['session']
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change,
                # but are not as expensive to compare as buffers

            for key, evalue in iteritems(existing):
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::invalid task result message sent to %r: %r",
                    client_id, msg, exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!", msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        md = msg['metadata']
        engine_uuid = md.get('engine', u'')
        eid = self.by_ident.get(cast_bytes(engine_uuid), None)

        status = md.get('status', None)

        if msg_id in self.pending:
            self.log.info("task::task %r finished on %s", msg_id, eid)
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                if status != 'aborted':
                    self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = md.get('started', None)
            result = {
                'result_header' : header,
                'result_metadata': msg['metadata'],
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'received' : datetime.now(),
                'engine_uuid': engine_uuid,
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished", msg_id)

    def save_task_destination(self, idents, msg):
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[cast_bytes(engine_uuid)]

        self.log.info("task::task %r arrived on %r", msg_id, eid)
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)


    def mia_task_request(self, idents, msg):
        raise NotImplementedError
        client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)


    #--------------------- IOPub Traffic ------------------------------

    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.warn("iopub::IOPub message lacks parent: %r", msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        elif msg_type in ('display_data', 'pyout'):
            d[msg_type] = content
        elif msg_type == 'status':
            pass
        elif msg_type == 'data_pub':
            self.log.info("ignored data_pub message for %s" % msg_id)
        else:
            self.log.warn("unhandled iopub msg_type: %r", msg_type)

        if not d:
            return

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)


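    # Illustrative note (not in the original source): stream output accumulates
    # per record, e.g. two stdout fragments 'hel' and 'lo\n' for the same msg_id
    # end up stored as rec['stdout'] == 'hello\n' after the updates above.
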
    #-------------------------------------------------------------------------
    # Registration requests
    #-------------------------------------------------------------------------

    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %r connected", client_id)
        content = dict(status='ok')
        jsonable = {}
        for k, v in iteritems(self.keytable):
            if v not in self.dead_engines:
                jsonable[str(k)] = v
        content['engines'] = jsonable
        self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)

    def register_engine(self, reg, msg):
        """Register a new engine."""
        content = msg['content']
        try:
            uuid = content['uuid']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return

        eid = self._next_id

        self.log.debug("registration::register_engine(%i, %r)", eid, uuid)

        content = dict(id=eid, status='ok', hb_period=self.heartmonitor.period)
        # check if requesting available IDs:
        if cast_bytes(uuid) in self.by_ident:
            try:
                raise KeyError("uuid %r in use" % uuid)
            except:
                content = error.wrap_exception()
                self.log.error("uuid %r in use", uuid, exc_info=True)
        else:
            for h, ec in iteritems(self.incoming_registrations):
                if uuid == h:
                    try:
                        raise KeyError("heart_id %r in use" % uuid)
                    except:
                        self.log.error("heart_id %r in use", uuid, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif uuid == ec.uuid:
                    try:
                        raise KeyError("uuid %r in use" % uuid)
                    except:
                        self.log.error("uuid %r in use", uuid, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        heart = cast_bytes(uuid)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid)
                self.finish_registration(heart)
            else:
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid, stallback=dc)
        else:
            self.log.error("registration::registration %i failed: %r", eid, content['evalue'])

        return eid

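    # Illustrative sketch (not in the original source) of a registration_reply
    # content on success, per the code above:
    #   {'id': 3, 'status': 'ok', 'hb_period': <HeartMonitor period>}
    # On a duplicate uuid the content is error.wrap_exception() output instead
    # (an error status with ename/evalue/traceback fields).
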
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)", eid)
        # print (eid)
        uuid = self.keytable[eid]
        content = dict(id=eid, uuid=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        self._save_engine_state()

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)


    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            ec = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
        if ec.stallback is not None:
            ec.stallback.stop()
        eid = ec.id
        self.ids.add(eid)
        self.keytable[eid] = ec.uuid
        self.engines[eid] = ec
        self.by_ident[cast_bytes(ec.uuid)] = ec.id
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, uuid=self.engines[eid].uuid)
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i", eid)

        self._save_engine_state()

    def _purge_stalled_registration(self, heart):
        if heart in self.incoming_registrations:
            ec = self.incoming_registrations.pop(heart)
            self.log.info("registration::purging stalled registration: %i", ec.id)
        else:
            pass

    #-------------------------------------------------------------------------
    # Engine State
    #-------------------------------------------------------------------------


    def _cleanup_engine_state_file(self):
        """cleanup engine state mapping"""

        if os.path.exists(self.engine_state_file):
            self.log.debug("cleaning up engine state: %s", self.engine_state_file)
            try:
                os.remove(self.engine_state_file)
            except IOError:
                self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)


    def _save_engine_state(self):
        """save engine mapping to JSON file"""
        if not self.engine_state_file:
            return
        self.log.debug("save engine state to %s" % self.engine_state_file)
        state = {}
        engines = {}
        for eid, ec in iteritems(self.engines):
            if ec.uuid not in self.dead_engines:
                engines[eid] = ec.uuid

        state['engines'] = engines

        state['next_id'] = self._idcounter

        with open(self.engine_state_file, 'w') as f:
            json.dump(state, f)


    def _load_engine_state(self):
        """load engine mapping from JSON file"""
        if not os.path.exists(self.engine_state_file):
            return

        self.log.info("loading engine state from %s" % self.engine_state_file)

        with open(self.engine_state_file) as f:
            state = json.load(f)

        save_notifier = self.notifier
        self.notifier = None
        for eid, uuid in iteritems(state['engines']):
            heart = uuid.encode('ascii')
            # start with this heart as current and beating:
            self.heartmonitor.responses.add(heart)
            self.heartmonitor.hearts.add(heart)

            self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
            self.finish_registration(heart)

        self.notifier = save_notifier

        self._idcounter = state['next_id']

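    # Illustrative sketch (not in the original source) of the JSON written by
    # _save_engine_state and read back by _load_engine_state:
    #   {"engines": {"0": "0e9a...-uuid", "1": "7f3c...-uuid"}, "next_id": 2}
    # (JSON object keys are strings, hence the int(eid) cast on reload.)
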
    #-------------------------------------------------------------------------
    # Client Requests
    #-------------------------------------------------------------------------

    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()

    def _shutdown(self):
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)


    def check_load(self, client_id, msg):
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            content[bytes(t)] = len(self.queues[t]) + len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)


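    # Illustrative note (not in the original source): an engine's "load" here is
    # simply its count of pending MUX requests plus pending task-scheme tasks,
    # so a load_reply for engines 0 and 1 looks roughly like
    #   {'status': 'ok', '0': 2, '1': 0}
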
    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.

        if verbose: return the msg_ids
        else: return len of each type.

        keys:
            queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)
        """
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
        # print (content)
        self.session.send(self.query, "queue_reply", content=content, ident=client_id)

    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            pending = [m for m in msg_ids if (m in self.pending)]
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].uuid
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)

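    # Illustrative sketch (not in the original source) of purge_request contents
    # this handler accepts:
    #   {'msg_ids': 'all'}                      # drop every completed record
    #   {'msg_ids': ['<id1>', '<id2>']}         # drop specific completed records
    #   {'msg_ids': [], 'engine_ids': [0, 1]}   # drop completed records per engine
    # Requests naming still-pending msg_ids are refused with a wrapped IndexError.
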
    def resubmit_task(self, client_id, msg):
        """Resubmit one or more tasks."""
        def finish(reply):
            self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

        content = msg['content']
        msg_ids = content['msg_ids']
        reply = dict(status='ok')
        try:
            records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
                'header', 'content', 'buffers'])
        except Exception:
            self.log.error('db::db error finding tasks to resubmit', exc_info=True)
            return finish(error.wrap_exception())

        # validate msg_ids
        found_ids = [rec['msg_id'] for rec in records]
        pending_ids = [msg_id for msg_id in found_ids if msg_id in self.pending]
        if len(records) > len(msg_ids):
            try:
                raise RuntimeError("DB appears to be in an inconsistent state. "
                        "More matching records were found than should exist")
            except Exception:
                return finish(error.wrap_exception())
        elif len(records) < len(msg_ids):
            missing = [m for m in msg_ids if m not in found_ids]
            try:
                raise KeyError("No such msg(s): %r" % missing)
            except KeyError:
                return finish(error.wrap_exception())
        elif pending_ids:
            pass
            # no need to raise on resubmit of pending task, now that we
            # resubmit under new ID, but do we want to raise anyway?
            # msg_id = invalid_ids[0]
            # try:
            #     raise ValueError("Task(s) %r appears to be inflight" % )
            # except Exception:
            #     return finish(error.wrap_exception())

        # mapping of original IDs to resubmitted IDs
        resubmitted = {}

        # send the messages
        for rec in records:
            header = rec['header']
            msg = self.session.msg(header['msg_type'], parent=header)
            msg_id = msg['msg_id']
            msg['content'] = rec['content']

            # use the old header, but update msg_id and timestamp
            fresh = msg['header']
            header['msg_id'] = fresh['msg_id']
            header['date'] = fresh['date']
            msg['header'] = header

            self.session.send(self.resubmit, msg, buffers=rec['buffers'])

            resubmitted[rec['msg_id']] = msg_id
            self.pending.add(msg_id)
            msg['buffers'] = rec['buffers']
            try:
                self.db.add_record(msg_id, init_record(msg))
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
                return finish(error.wrap_exception())

        finish(dict(status='ok', resubmitted=resubmitted))

        # store the new IDs in the Task DB
        for msg_id, resubmit_id in iteritems(resubmitted):
            try:
                self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)

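    # Illustrative note (not in the original source): resubmission always mints
    # a new msg_id; the reply's 'resubmitted' mapping (old id -> new id) is what
    # lets a client follow its original handle to the new task's result.
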
1300 def _extract_record(self, rec):
1301 def _extract_record(self, rec):
1301 """decompose a TaskRecord dict into subsection of reply for get_result"""
1302 """decompose a TaskRecord dict into subsection of reply for get_result"""
1302 io_dict = {}
1303 io_dict = {}
1303 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1304 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1304 io_dict[key] = rec[key]
1305 io_dict[key] = rec[key]
1305 content = {
1306 content = {
1306 'header': rec['header'],
1307 'header': rec['header'],
1307 'metadata': rec['metadata'],
1308 'metadata': rec['metadata'],
1308 'result_metadata': rec['result_metadata'],
1309 'result_metadata': rec['result_metadata'],
1309 'result_header' : rec['result_header'],
1310 'result_header' : rec['result_header'],
1310 'result_content': rec['result_content'],
1311 'result_content': rec['result_content'],
1311 'received' : rec['received'],
1312 'received' : rec['received'],
1312 'io' : io_dict,
1313 'io' : io_dict,
1313 }
1314 }
1314 if rec['result_buffers']:
1315 if rec['result_buffers']:
1315 buffers = list(map(bytes, rec['result_buffers']))
1316 buffers = list(map(bytes, rec['result_buffers']))
1316 else:
1317 else:
1317 buffers = []
1318 buffers = []
1318
1319
1319 return content, buffers
1320 return content, buffers
1320
1321
1321 def get_results(self, client_id, msg):
1322 def get_results(self, client_id, msg):
1322 """Get the result of 1 or more messages."""
1323 """Get the result of 1 or more messages."""
1323 content = msg['content']
1324 content = msg['content']
1324 msg_ids = sorted(set(content['msg_ids']))
1325 msg_ids = sorted(set(content['msg_ids']))
1325 statusonly = content.get('status_only', False)
1326 statusonly = content.get('status_only', False)
1326 pending = []
1327 pending = []
1327 completed = []
1328 completed = []
1328 content = dict(status='ok')
1329 content = dict(status='ok')
1329 content['pending'] = pending
1330 content['pending'] = pending
1330 content['completed'] = completed
1331 content['completed'] = completed
1331 buffers = []
1332 buffers = []
1332 if not statusonly:
1333 if not statusonly:
1333 try:
1334 try:
1334 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1335 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1335 # turn match list into dict, for faster lookup
1336 # turn match list into dict, for faster lookup
1336 records = {}
1337 records = {}
1337 for rec in matches:
1338 for rec in matches:
1338 records[rec['msg_id']] = rec
1339 records[rec['msg_id']] = rec
1339 except Exception:
1340 except Exception:
1340 content = error.wrap_exception()
1341 content = error.wrap_exception()
1341 self.session.send(self.query, "result_reply", content=content,
1342 self.session.send(self.query, "result_reply", content=content,
1342 parent=msg, ident=client_id)
1343 parent=msg, ident=client_id)
1343 return
1344 return
1344 else:
1345 else:
1345 records = {}
1346 records = {}
1346 for msg_id in msg_ids:
1347 for msg_id in msg_ids:
1347 if msg_id in self.pending:
1348 if msg_id in self.pending:
1348 pending.append(msg_id)
1349 pending.append(msg_id)
1349 elif msg_id in self.all_completed:
1350 elif msg_id in self.all_completed:
1350 completed.append(msg_id)
1351 completed.append(msg_id)
1351 if not statusonly:
1352 if not statusonly:
1352 c,bufs = self._extract_record(records[msg_id])
1353 c,bufs = self._extract_record(records[msg_id])
1353 content[msg_id] = c
1354 content[msg_id] = c
1354 buffers.extend(bufs)
1355 buffers.extend(bufs)
1355 elif msg_id in records:
1356 elif msg_id in records:
1356 if records[msg_id]['completed']:
1357 if records[msg_id]['completed']:
1357 completed.append(msg_id)
1358 completed.append(msg_id)
1358 c,bufs = self._extract_record(records[msg_id])
1359 c,bufs = self._extract_record(records[msg_id])
1359 content[msg_id] = c
1360 content[msg_id] = c
1360 buffers.extend(bufs)
1361 buffers.extend(bufs)
1361 else:
1362 else:
1362 pending.append(msg_id)
1363 pending.append(msg_id)
1363 else:
1364 else:
1364 try:
1365 try:
1365 raise KeyError('No such message: '+msg_id)
1366 raise KeyError('No such message: '+msg_id)
1366 except Exception:
1367 except Exception:
1367 content = error.wrap_exception()
1368 content = error.wrap_exception()
1368 break
1369 break
1369 self.session.send(self.query, "result_reply", content=content,
1370 self.session.send(self.query, "result_reply", content=content,
1370 parent=msg, ident=client_id,
1371 parent=msg, ident=client_id,
1371 buffers=buffers)
1372 buffers=buffers)
1372
1373
1373 def get_history(self, client_id, msg):
1374 def get_history(self, client_id, msg):
1374 """Get a list of all msg_ids in our DB records"""
1375 """Get a list of all msg_ids in our DB records"""
1375 try:
1376 try:
1376 msg_ids = self.db.get_history()
1377 msg_ids = self.db.get_history()
1377 except Exception:
1378 except Exception:
1378 content = error.wrap_exception()
1379 content = error.wrap_exception()
1379 else:
1380 else:
1380 content = dict(status='ok', history=msg_ids)
1381 content = dict(status='ok', history=msg_ids)
1381
1382
1382 self.session.send(self.query, "history_reply", content=content,
1383 self.session.send(self.query, "history_reply", content=content,
1383 parent=msg, ident=client_id)
1384 parent=msg, ident=client_id)
1384
1385
1385 def db_query(self, client_id, msg):
1386 def db_query(self, client_id, msg):
1386 """Perform a raw query on the task record database."""
1387 """Perform a raw query on the task record database."""
1387 content = msg['content']
1388 content = msg['content']
1388 query = content.get('query', {})
1389 query = extract_dates(content.get('query', {}))
1389 keys = content.get('keys', None)
1390 keys = content.get('keys', None)
1390 buffers = []
1391 buffers = []
1391 empty = list()
1392 empty = list()
1392 try:
1393 try:
1393 records = self.db.find_records(query, keys)
1394 records = self.db.find_records(query, keys)
1394 except Exception:
1395 except Exception:
1395 content = error.wrap_exception()
1396 content = error.wrap_exception()
1396 else:
1397 else:
1397 # extract buffers from reply content:
1398 # extract buffers from reply content:
1398 if keys is not None:
1399 if keys is not None:
1399 buffer_lens = [] if 'buffers' in keys else None
1400 buffer_lens = [] if 'buffers' in keys else None
1400 result_buffer_lens = [] if 'result_buffers' in keys else None
1401 result_buffer_lens = [] if 'result_buffers' in keys else None
1401 else:
1402 else:
1402 buffer_lens = None
1403 buffer_lens = None
1403 result_buffer_lens = None
1404 result_buffer_lens = None
1404
1405
1405 for rec in records:
1406 for rec in records:
1406 # buffers may be None, so double check
1407 # buffers may be None, so double check
1407 b = rec.pop('buffers', empty) or empty
1408 b = rec.pop('buffers', empty) or empty
1408 if buffer_lens is not None:
1409 if buffer_lens is not None:
1409 buffer_lens.append(len(b))
1410 buffer_lens.append(len(b))
1410 buffers.extend(b)
1411 buffers.extend(b)
1411 rb = rec.pop('result_buffers', empty) or empty
1412 rb = rec.pop('result_buffers', empty) or empty
1412 if result_buffer_lens is not None:
1413 if result_buffer_lens is not None:
1413 result_buffer_lens.append(len(rb))
1414 result_buffer_lens.append(len(rb))
1414 buffers.extend(rb)
1415 buffers.extend(rb)
1415 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1416 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1416 result_buffer_lens=result_buffer_lens)
1417 result_buffer_lens=result_buffer_lens)
1417 # self.log.debug (content)
1418 # self.log.debug (content)
1418 self.session.send(self.query, "db_reply", content=content,
1419 self.session.send(self.query, "db_reply", content=content,
1419 parent=msg, ident=client_id,
1420 parent=msg, ident=client_id,
1420 buffers=buffers)
1421 buffers=buffers)
1421
1422
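The only functional change to db_query above is wrapping the incoming query in extract_dates. The query arrives as unpacked JSON, so any datetime the client sent (e.g. in a '$gt' comparison on submission time) has been flattened to an ISO8601 string, while the DB backends store and compare real datetime objects. A rough illustration, assuming the jsonutil helpers shown in the next file:

    from datetime import datetime
    from IPython.utils.jsonutil import extract_dates

    wire_query = {'submitted': {'$gt': '2013-07-03T16:34:52.249482'}}
    query = extract_dates(wire_query)
    assert isinstance(query['submitted']['$gt'], datetime)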
@@ -1,229 +1,241
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import math
14 import math
15 import re
15 import re
16 import types
16 import types
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 # base64.encodestring is deprecated in Python 3.x
20 # base64.encodestring is deprecated in Python 3.x
21 from base64 import encodebytes
21 from base64 import encodebytes
22 except ImportError:
22 except ImportError:
23 # Python 2.x
23 # Python 2.x
24 from base64 import encodestring as encodebytes
24 from base64 import encodestring as encodebytes
25
25
26 from IPython.utils import py3compat
26 from IPython.utils import py3compat
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
28 from IPython.utils.encoding import DEFAULT_ENCODING
28 from IPython.utils.encoding import DEFAULT_ENCODING
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Globals and constants
32 # Globals and constants
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 # timestamp formats
35 # timestamp formats
36 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
36 ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+)Z?([\+\-]\d{2}:?\d{2})?$")
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6})Z?([\+\-]\d{2}:?\d{2})?$")
38
38
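The tightened pattern accepts one to six fractional-second digits (the range datetime.strptime's %f can actually parse), where the old \d+ accepted any run of digits and would then fail inside strptime. A quick check of what the new pattern admits, assuming the module constants above:

    assert ISO8601_PAT.match('2013-07-03T16:34:52.249482')            # 6 digits: ok
    assert ISO8601_PAT.match('2013-07-03T16:34:52.2Z+08:00')          # 1 digit + tz: ok
    assert ISO8601_PAT.match('2013-07-03T16:34:52.1234567') is None   # 7 digits: rejected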
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Classes and functions
40 # Classes and functions
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 def rekey(dikt):
43 def rekey(dikt):
44 """Rekey a dict that has been forced to use str keys where there should be
44 """Rekey a dict that has been forced to use str keys where there should be
45 ints by json."""
45 ints by json."""
46 for k in list(dikt):
46 for k in list(dikt):
47 if isinstance(k, string_types):
47 if isinstance(k, string_types):
48 ik=fk=None
48 ik=fk=None
49 try:
49 try:
50 ik = int(k)
50 ik = int(k)
51 except ValueError:
51 except ValueError:
52 try:
52 try:
53 fk = float(k)
53 fk = float(k)
54 except ValueError:
54 except ValueError:
55 continue
55 continue
56 if ik is not None:
56 if ik is not None:
57 nk = ik
57 nk = ik
58 else:
58 else:
59 nk = fk
59 nk = fk
60 if nk in dikt:
60 if nk in dikt:
61 raise KeyError("already have key %r"%nk)
61 raise KeyError("already have key %r"%nk)
62 dikt[nk] = dikt.pop(k)
62 dikt[nk] = dikt.pop(k)
63 return dikt
63 return dikt
64
64
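rekey mutates and returns the same dict, converting numeric-looking string keys back to int (or float) and raising if a conversion would collide with an existing key. For example:

    rekey({'1': 'a', '2.5': 'b', 'x': 'c'})
    # -> {1: 'a', 2.5: 'b', 'x': 'c'}; non-numeric keys are left alone
    # rekey({1: 'int', '1': 'str'}) would raise KeyError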
65 def parse_date(s):
66 """parse an ISO8601 date string
67
68 If it is None or not a valid ISO8601 timestamp,
69 it will be returned unmodified.
70 Otherwise, it will return a datetime object.
71 """
72 if s is None:
73 return s
74 m = ISO8601_PAT.match(s)
75 if m:
76 # FIXME: add actual timezone support
77 # this just drops the timezone info
78 notz = m.groups()[0]
79 return datetime.strptime(notz, ISO8601)
80 return s
65
81
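parse_date deliberately falls through to returning its input, so callers can apply it to arbitrary values without pre-checking; per the FIXME, a trailing timezone offset is currently dropped rather than applied. For example:

    parse_date('2013-07-03T16:34:52.249482')
    # -> datetime.datetime(2013, 7, 3, 16, 34, 52, 249482)
    parse_date('2013-07-03T16:34:52.249482+08:00')
    # -> same datetime; the offset is discarded for now
    parse_date('not a timestamp')   # -> 'not a timestamp', unchanged
    parse_date(None)                # -> None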
66 def extract_dates(obj):
82 def extract_dates(obj):
67 """extract ISO8601 dates from unpacked JSON"""
83 """extract ISO8601 dates from unpacked JSON"""
68 if isinstance(obj, dict):
84 if isinstance(obj, dict):
69 obj = dict(obj) # don't clobber
85 new_obj = {} # don't clobber
70 for k,v in iteritems(obj):
86 for k,v in iteritems(obj):
71 obj[k] = extract_dates(v)
87 new_obj[k] = extract_dates(v)
88 obj = new_obj
72 elif isinstance(obj, (list, tuple)):
89 elif isinstance(obj, (list, tuple)):
73 obj = [ extract_dates(o) for o in obj ]
90 obj = [ extract_dates(o) for o in obj ]
74 elif isinstance(obj, string_types):
91 elif isinstance(obj, string_types):
75 m = ISO8601_PAT.match(obj)
92 obj = parse_date(obj)
76 if m:
77 # FIXME: add actual timezone support
78 # this just drops the timezone info
79 notz = m.groups()[0]
80 obj = datetime.strptime(notz, ISO8601)
81 return obj
93 return obj
82
94
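The rewrite of the dict branch builds a fresh dict instead of copying and then reassigning values while iterating, so the caller's structure is never modified mid-walk. The observable contract is that the input is left alone and the return value carries the converted dates:

    before = {'header': {'date': '2013-07-03T16:34:52.249482'}}
    after = extract_dates(before)
    assert isinstance(after['header']['date'], datetime)
    assert before['header']['date'] == '2013-07-03T16:34:52.249482'  # untouched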
83 def squash_dates(obj):
95 def squash_dates(obj):
84 """squash datetime objects into ISO8601 strings"""
96 """squash datetime objects into ISO8601 strings"""
85 if isinstance(obj, dict):
97 if isinstance(obj, dict):
86 obj = dict(obj) # don't clobber
98 obj = dict(obj) # don't clobber
87 for k,v in iteritems(obj):
99 for k,v in iteritems(obj):
88 obj[k] = squash_dates(v)
100 obj[k] = squash_dates(v)
89 elif isinstance(obj, (list, tuple)):
101 elif isinstance(obj, (list, tuple)):
90 obj = [ squash_dates(o) for o in obj ]
102 obj = [ squash_dates(o) for o in obj ]
91 elif isinstance(obj, datetime):
103 elif isinstance(obj, datetime):
92 obj = obj.isoformat()
104 obj = obj.isoformat()
93 return obj
105 return obj
94
106
95 def date_default(obj):
107 def date_default(obj):
96 """default function for packing datetime objects in JSON."""
108 """default function for packing datetime objects in JSON."""
97 if isinstance(obj, datetime):
109 if isinstance(obj, datetime):
98 return obj.isoformat()
110 return obj.isoformat()
99 else:
111 else:
100 raise TypeError("%r is not JSON serializable"%obj)
112 raise TypeError("%r is not JSON serializable"%obj)
101
113
102
114
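Together, date_default and extract_dates give a lossless JSON round-trip for naive, microsecond-precision datetimes (note that isoformat() drops the fraction entirely when microsecond == 0, and ISO8601_PAT requires one, so such values come back as strings):

    import json
    from datetime import datetime

    data = {'now': datetime(2013, 7, 3, 16, 34, 52, 249482)}
    wire = json.dumps(data, default=date_default)
    back = extract_dates(json.loads(wire))
    assert back['now'] == data['now']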
103 # constants for identifying png/jpeg data
115 # constants for identifying png/jpeg data
104 PNG = b'\x89PNG\r\n\x1a\n'
116 PNG = b'\x89PNG\r\n\x1a\n'
105 # front of PNG base64-encoded
117 # front of PNG base64-encoded
106 PNG64 = b'iVBORw0KG'
118 PNG64 = b'iVBORw0KG'
107 JPEG = b'\xff\xd8'
119 JPEG = b'\xff\xd8'
108 # front of JPEG base64-encoded
120 # front of JPEG base64-encoded
109 JPEG64 = b'/9'
121 JPEG64 = b'/9'
110
122
111 def encode_images(format_dict):
123 def encode_images(format_dict):
112 """b64-encodes images in a displaypub format dict
124 """b64-encodes images in a displaypub format dict
113
125
114 Perhaps this should be handled in json_clean itself?
126 Perhaps this should be handled in json_clean itself?
115
127
116 Parameters
128 Parameters
117 ----------
129 ----------
118
130
119 format_dict : dict
131 format_dict : dict
120 A dictionary of display data keyed by mime-type
132 A dictionary of display data keyed by mime-type
121
133
122 Returns
134 Returns
123 -------
135 -------
124
136
125 format_dict : dict
137 format_dict : dict
126 A copy of the same dictionary,
138 A copy of the same dictionary,
127 but binary image data ('image/png' or 'image/jpeg')
139 but binary image data ('image/png' or 'image/jpeg')
128 is base64-encoded.
140 is base64-encoded.
129
141
130 """
142 """
131 encoded = format_dict.copy()
143 encoded = format_dict.copy()
132
144
133 pngdata = format_dict.get('image/png')
145 pngdata = format_dict.get('image/png')
134 if isinstance(pngdata, bytes):
146 if isinstance(pngdata, bytes):
135 # make sure we don't double-encode
147 # make sure we don't double-encode
136 if not pngdata.startswith(PNG64):
148 if not pngdata.startswith(PNG64):
137 pngdata = encodebytes(pngdata)
149 pngdata = encodebytes(pngdata)
138 encoded['image/png'] = pngdata.decode('ascii')
150 encoded['image/png'] = pngdata.decode('ascii')
139
151
140 jpegdata = format_dict.get('image/jpeg')
152 jpegdata = format_dict.get('image/jpeg')
141 if isinstance(jpegdata, bytes):
153 if isinstance(jpegdata, bytes):
142 # make sure we don't double-encode
154 # make sure we don't double-encode
143 if not jpegdata.startswith(JPEG64):
155 if not jpegdata.startswith(JPEG64):
144 jpegdata = encodebytes(jpegdata)
156 jpegdata = encodebytes(jpegdata)
145 encoded['image/jpeg'] = jpegdata.decode('ascii')
157 encoded['image/jpeg'] = jpegdata.decode('ascii')
146
158
147 return encoded
159 return encoded
148
160
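Because encode_images sniffs the base64 prefixes (PNG64, JPEG64) before encoding, it is idempotent and safe to call on a format dict that may already have been processed:

    fmt = {'image/png': b'\x89PNG\r\n\x1a\n' + b'fake-pixel-data'}
    once = encode_images(fmt)      # raw bytes -> base64 str
    twice = encode_images(once)    # already-encoded values pass through
    assert once == twice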
149
161
150 def json_clean(obj):
162 def json_clean(obj):
151 """Clean an object to ensure it's safe to encode in JSON.
163 """Clean an object to ensure it's safe to encode in JSON.
152
164
153 Atomic, immutable objects are returned unmodified. Sets and tuples are
165 Atomic, immutable objects are returned unmodified. Sets and tuples are
154 converted to lists, lists are copied and dicts are also copied.
166 converted to lists, lists are copied and dicts are also copied.
155
167
156 Note: dicts whose keys could cause collisions upon encoding (such as a dict
168 Note: dicts whose keys could cause collisions upon encoding (such as a dict
157 with both the number 1 and the string '1' as keys) will cause a ValueError
169 with both the number 1 and the string '1' as keys) will cause a ValueError
158 to be raised.
170 to be raised.
159
171
160 Parameters
172 Parameters
161 ----------
173 ----------
162 obj : any python object
174 obj : any python object
163
175
164 Returns
176 Returns
165 -------
177 -------
166 out : object
178 out : object
167
179
168 A version of the input which will not cause an encoding error when
180 A version of the input which will not cause an encoding error when
169 encoded as JSON. Note that this function does not *encode* its inputs,
181 encoded as JSON. Note that this function does not *encode* its inputs,
170 it simply sanitizes them so that there will be no encoding errors later.
182 it simply sanitizes them so that there will be no encoding errors later.
171
183
172 Examples
184 Examples
173 --------
185 --------
174 >>> json_clean(4)
186 >>> json_clean(4)
175 4
187 4
176 >>> json_clean(list(range(10)))
188 >>> json_clean(list(range(10)))
177 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
189 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
178 >>> sorted(json_clean(dict(x=1, y=2)).items())
190 >>> sorted(json_clean(dict(x=1, y=2)).items())
179 [('x', 1), ('y', 2)]
191 [('x', 1), ('y', 2)]
180 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
192 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
181 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
193 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
182 >>> json_clean(True)
194 >>> json_clean(True)
183 True
195 True
184 """
196 """
185 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
197 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
186 # listed explicitly because bools pass as int instances
198 # listed explicitly because bools pass as int instances
187 atomic_ok = (unicode_type, int, type(None))
199 atomic_ok = (unicode_type, int, type(None))
188
200
189 # containers that we need to convert into lists
201 # containers that we need to convert into lists
190 container_to_list = (tuple, set, types.GeneratorType)
202 container_to_list = (tuple, set, types.GeneratorType)
191
203
192 if isinstance(obj, float):
204 if isinstance(obj, float):
193 # cast out-of-range floats to their reprs
205 # cast out-of-range floats to their reprs
194 if math.isnan(obj) or math.isinf(obj):
206 if math.isnan(obj) or math.isinf(obj):
195 return repr(obj)
207 return repr(obj)
196 return obj
208 return obj
197
209
198 if isinstance(obj, atomic_ok):
210 if isinstance(obj, atomic_ok):
199 return obj
211 return obj
200
212
201 if isinstance(obj, bytes):
213 if isinstance(obj, bytes):
202 return obj.decode(DEFAULT_ENCODING, 'replace')
214 return obj.decode(DEFAULT_ENCODING, 'replace')
203
215
204 if isinstance(obj, container_to_list) or (
216 if isinstance(obj, container_to_list) or (
205 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
217 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
206 obj = list(obj)
218 obj = list(obj)
207
219
208 if isinstance(obj, list):
220 if isinstance(obj, list):
209 return [json_clean(x) for x in obj]
221 return [json_clean(x) for x in obj]
210
222
211 if isinstance(obj, dict):
223 if isinstance(obj, dict):
212 # First, validate that the dict won't lose data in conversion due to
224 # First, validate that the dict won't lose data in conversion due to
213 # key collisions after stringification. This can happen with keys like
225 # key collisions after stringification. This can happen with keys like
214 # True and 'true' or 1 and '1', which collide in JSON.
226 # True and 'true' or 1 and '1', which collide in JSON.
215 nkeys = len(obj)
227 nkeys = len(obj)
216 nkeys_collapsed = len(set(map(str, obj)))
228 nkeys_collapsed = len(set(map(str, obj)))
217 if nkeys != nkeys_collapsed:
229 if nkeys != nkeys_collapsed:
218 raise ValueError('dict can not be safely converted to JSON: '
230 raise ValueError('dict can not be safely converted to JSON: '
219 'key collision would lead to dropped values')
231 'key collision would lead to dropped values')
220 # If all OK, proceed by making the new dict that will be json-safe
232 # If all OK, proceed by making the new dict that will be json-safe
221 out = {}
233 out = {}
222 for k,v in iteritems(obj):
234 for k,v in iteritems(obj):
223 out[str(k)] = json_clean(v)
235 out[str(k)] = json_clean(v)
224 return out
236 return out
225
237
226 # If we get here, we don't know how to handle the object, so we just get
238 # If we get here, we don't know how to handle the object, so we just get
227 # its repr and return that. This will catch lambdas, open sockets, class
239 # its repr and return that. This will catch lambdas, open sockets, class
228 # objects, and any other complicated contraption that json can't encode
240 # objects, and any other complicated contraption that json can't encode
229 return repr(obj)
241 return repr(obj)
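Two behaviors worth noting that the doctests above don't cover: bytes are decoded (with replacement) rather than preserved, and anything json_clean can't classify, including datetime objects, which this module otherwise handles via date_default, comes back as its repr rather than raising:

    json_clean(b'abc')       # -> u'abc', decoded with DEFAULT_ENCODING
    json_clean(lambda: 1)    # -> '<function <lambda> at 0x...>'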
@@ -1,131 +1,143
1 """Test suite for our JSON utilities.
1 """Test suite for our JSON utilities.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import datetime
14 import datetime
15 import json
15 import json
16 from base64 import decodestring
16 from base64 import decodestring
17
17
18 # third party
18 # third party
19 import nose.tools as nt
19 import nose.tools as nt
20
20
21 # our own
21 # our own
22 from IPython.utils import jsonutil, tz
22 from IPython.utils import jsonutil, tz
23 from ..jsonutil import json_clean, encode_images
23 from ..jsonutil import json_clean, encode_images
24 from ..py3compat import unicode_to_str, str_to_bytes, iteritems
24 from ..py3compat import unicode_to_str, str_to_bytes, iteritems
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Test functions
27 # Test functions
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30 def test():
30 def test():
31 # list of input/expected output. Use None for the expected output if it
31 # list of input/expected output. Use None for the expected output if it
32 # can be the same as the input.
32 # can be the same as the input.
33 pairs = [(1, None), # start with scalars
33 pairs = [(1, None), # start with scalars
34 (1.0, None),
34 (1.0, None),
35 ('a', None),
35 ('a', None),
36 (True, None),
36 (True, None),
37 (False, None),
37 (False, None),
38 (None, None),
38 (None, None),
39 # complex numbers for now just go to strings, as otherwise they
39 # complex numbers for now just go to strings, as otherwise they
40 # are unserializable
40 # are unserializable
41 (1j, '1j'),
41 (1j, '1j'),
42 # Containers
42 # Containers
43 ([1, 2], None),
43 ([1, 2], None),
44 ((1, 2), [1, 2]),
44 ((1, 2), [1, 2]),
45 (set([1, 2]), [1, 2]),
45 (set([1, 2]), [1, 2]),
46 (dict(x=1), None),
46 (dict(x=1), None),
47 ({'x': 1, 'y':[1,2,3], '1':'int'}, None),
47 ({'x': 1, 'y':[1,2,3], '1':'int'}, None),
48 # More exotic objects
48 # More exotic objects
49 ((x for x in range(3)), [0, 1, 2]),
49 ((x for x in range(3)), [0, 1, 2]),
50 (iter([1, 2]), [1, 2]),
50 (iter([1, 2]), [1, 2]),
51 ]
51 ]
52
52
53 for val, jval in pairs:
53 for val, jval in pairs:
54 if jval is None:
54 if jval is None:
55 jval = val
55 jval = val
56 out = json_clean(val)
56 out = json_clean(val)
57 # validate our cleanup
57 # validate our cleanup
58 nt.assert_equal(out, jval)
58 nt.assert_equal(out, jval)
59 # and ensure that what we return, indeed encodes cleanly
59 # and ensure that what we return, indeed encodes cleanly
60 json.loads(json.dumps(out))
60 json.loads(json.dumps(out))
61
61
62
62
63
63
64 def test_encode_images():
64 def test_encode_images():
65 # invalid data, but the header and footer are from real files
65 # invalid data, but the header and footer are from real files
66 pngdata = b'\x89PNG\r\n\x1a\nblahblahnotactuallyvalidIEND\xaeB`\x82'
66 pngdata = b'\x89PNG\r\n\x1a\nblahblahnotactuallyvalidIEND\xaeB`\x82'
67 jpegdata = b'\xff\xd8\xff\xe0\x00\x10JFIFblahblahjpeg(\xa0\x0f\xff\xd9'
67 jpegdata = b'\xff\xd8\xff\xe0\x00\x10JFIFblahblahjpeg(\xa0\x0f\xff\xd9'
68
68
69 fmt = {
69 fmt = {
70 'image/png' : pngdata,
70 'image/png' : pngdata,
71 'image/jpeg' : jpegdata,
71 'image/jpeg' : jpegdata,
72 }
72 }
73 encoded = encode_images(fmt)
73 encoded = encode_images(fmt)
74 for key, value in iteritems(fmt):
74 for key, value in iteritems(fmt):
75 # encoded has unicode, want bytes
75 # encoded has unicode, want bytes
76 decoded = decodestring(encoded[key].encode('ascii'))
76 decoded = decodestring(encoded[key].encode('ascii'))
77 nt.assert_equal(decoded, value)
77 nt.assert_equal(decoded, value)
78 encoded2 = encode_images(encoded)
78 encoded2 = encode_images(encoded)
79 nt.assert_equal(encoded, encoded2)
79 nt.assert_equal(encoded, encoded2)
80
80
81 b64_str = {}
81 b64_str = {}
82 for key, encoded in iteritems(encoded):
82 for key, encoded in iteritems(encoded):
83 b64_str[key] = unicode_to_str(encoded)
83 b64_str[key] = unicode_to_str(encoded)
84 encoded3 = encode_images(b64_str)
84 encoded3 = encode_images(b64_str)
85 nt.assert_equal(encoded3, b64_str)
85 nt.assert_equal(encoded3, b64_str)
86 for key, value in iteritems(fmt):
86 for key, value in iteritems(fmt):
87 # encoded3 has str, want bytes
87 # encoded3 has str, want bytes
88 decoded = decodestring(str_to_bytes(encoded3[key]))
88 decoded = decodestring(str_to_bytes(encoded3[key]))
89 nt.assert_equal(decoded, value)
89 nt.assert_equal(decoded, value)
90
90
91 def test_lambda():
91 def test_lambda():
92 jc = json_clean(lambda : 1)
92 jc = json_clean(lambda : 1)
93 assert isinstance(jc, str)
93 assert isinstance(jc, str)
94 assert '<lambda>' in jc
94 assert '<lambda>' in jc
95 json.dumps(jc)
95 json.dumps(jc)
96
96
97 def test_extract_dates():
97 def test_extract_dates():
98 timestamps = [
98 timestamps = [
99 '2013-07-03T16:34:52.249482',
99 '2013-07-03T16:34:52.249482',
100 '2013-07-03T16:34:52.249482Z',
100 '2013-07-03T16:34:52.249482Z',
101 '2013-07-03T16:34:52.249482Z-0800',
101 '2013-07-03T16:34:52.249482Z-0800',
102 '2013-07-03T16:34:52.249482Z+0800',
102 '2013-07-03T16:34:52.249482Z+0800',
103 '2013-07-03T16:34:52.249482Z+08:00',
103 '2013-07-03T16:34:52.249482Z+08:00',
104 '2013-07-03T16:34:52.249482Z-08:00',
104 '2013-07-03T16:34:52.249482Z-08:00',
105 '2013-07-03T16:34:52.249482-0800',
105 '2013-07-03T16:34:52.249482-0800',
106 '2013-07-03T16:34:52.249482+0800',
106 '2013-07-03T16:34:52.249482+0800',
107 '2013-07-03T16:34:52.249482+08:00',
107 '2013-07-03T16:34:52.249482+08:00',
108 '2013-07-03T16:34:52.249482-08:00',
108 '2013-07-03T16:34:52.249482-08:00',
109 ]
109 ]
110 extracted = jsonutil.extract_dates(timestamps)
110 extracted = jsonutil.extract_dates(timestamps)
111 ref = extracted[0]
111 ref = extracted[0]
112 for dt in extracted:
112 for dt in extracted:
113 nt.assert_true(isinstance(dt, datetime.datetime))
113 nt.assert_true(isinstance(dt, datetime.datetime))
114 nt.assert_equal(dt, ref)
114 nt.assert_equal(dt, ref)
115
115
116 def test_parse_ms_precision():
117 base = '2013-07-03T16:34:52.'
118 digits = '1234567890'
119
120 for i in range(len(digits)):
121 ts = base + digits[:i]
122 parsed = jsonutil.parse_date(ts)
123 if 1 <= i <= 6:
124 assert isinstance(parsed, datetime.datetime)
125 else:
126 assert isinstance(parsed, str)
127
116 def test_date_default():
128 def test_date_default():
117 data = dict(today=datetime.datetime.now(), utcnow=tz.utcnow())
129 data = dict(today=datetime.datetime.now(), utcnow=tz.utcnow())
118 jsondata = json.dumps(data, default=jsonutil.date_default)
130 jsondata = json.dumps(data, default=jsonutil.date_default)
119 nt.assert_in("+00", jsondata)
131 nt.assert_in("+00", jsondata)
120 nt.assert_equal(jsondata.count("+00"), 1)
132 nt.assert_equal(jsondata.count("+00"), 1)
121 extracted = jsonutil.extract_dates(json.loads(jsondata))
133 extracted = jsonutil.extract_dates(json.loads(jsondata))
122 for dt in extracted.values():
134 for dt in extracted.values():
123 nt.assert_true(isinstance(dt, datetime.datetime))
135 nt.assert_true(isinstance(dt, datetime.datetime))
124
136
125 def test_exception():
137 def test_exception():
126 bad_dicts = [{1:'number', '1':'string'},
138 bad_dicts = [{1:'number', '1':'string'},
127 {True:'bool', 'True':'string'},
139 {True:'bool', 'True':'string'},
128 ]
140 ]
129 for d in bad_dicts:
141 for d in bad_dicts:
130 nt.assert_raises(ValueError, json_clean, d)
142 nt.assert_raises(ValueError, json_clean, d)
131
143