##// END OF EJS Templates
Merge pull request #4840 from dsblank/master...
Thomas Kluyver -
r14648:8858f7ee merge
parent child Browse files
Show More
@@ -1,850 +1,850 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hashlib
27 import hashlib
28 import hmac
28 import hmac
29 import logging
29 import logging
30 import os
30 import os
31 import pprint
31 import pprint
32 import random
32 import random
33 import uuid
33 import uuid
34 from datetime import datetime
34 from datetime import datetime
35
35
36 try:
36 try:
37 import cPickle
37 import cPickle
38 pickle = cPickle
38 pickle = cPickle
39 except:
39 except:
40 cPickle = None
40 cPickle = None
41 import pickle
41 import pickle
42
42
43 import zmq
43 import zmq
44 from zmq.utils import jsonapi
44 from zmq.utils import jsonapi
45 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.ioloop import IOLoop
46 from zmq.eventloop.zmqstream import ZMQStream
46 from zmq.eventloop.zmqstream import ZMQStream
47
47
48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.config.configurable import Configurable, LoggingConfigurable
49 from IPython.utils import io
49 from IPython.utils import io
50 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 iteritems)
53 iteritems)
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 DottedObjectName, CUnicode, Dict, Integer,
55 DottedObjectName, CUnicode, Dict, Integer,
56 TraitError,
56 TraitError,
57 )
57 )
58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59
59
60 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
61 # utility functions
61 # utility functions
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63
63
64 def squash_unicode(obj):
64 def squash_unicode(obj):
65 """coerce unicode back to bytestrings."""
65 """coerce unicode back to bytestrings."""
66 if isinstance(obj,dict):
66 if isinstance(obj,dict):
67 for key in obj.keys():
67 for key in obj.keys():
68 obj[key] = squash_unicode(obj[key])
68 obj[key] = squash_unicode(obj[key])
69 if isinstance(key, unicode_type):
69 if isinstance(key, unicode_type):
70 obj[squash_unicode(key)] = obj.pop(key)
70 obj[squash_unicode(key)] = obj.pop(key)
71 elif isinstance(obj, list):
71 elif isinstance(obj, list):
72 for i,v in enumerate(obj):
72 for i,v in enumerate(obj):
73 obj[i] = squash_unicode(v)
73 obj[i] = squash_unicode(v)
74 elif isinstance(obj, unicode_type):
74 elif isinstance(obj, unicode_type):
75 obj = obj.encode('utf8')
75 obj = obj.encode('utf8')
76 return obj
76 return obj
77
77
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # globals and defaults
79 # globals and defaults
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
82 # ISO8601-ify datetime objects
82 # ISO8601-ify datetime objects
83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 json_unpacker = lambda s: jsonapi.loads(s)
84 json_unpacker = lambda s: jsonapi.loads(s)
85
85
86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
87 pickle_unpacker = pickle.loads
87 pickle_unpacker = pickle.loads
88
88
89 default_packer = json_packer
89 default_packer = json_packer
90 default_unpacker = json_unpacker
90 default_unpacker = json_unpacker
91
91
92 DELIM = b"<IDS|MSG>"
92 DELIM = b"<IDS|MSG>"
93 # singleton dummy tracker, which will always report as done
93 # singleton dummy tracker, which will always report as done
94 DONE = zmq.MessageTracker()
94 DONE = zmq.MessageTracker()
95
95
96 #-----------------------------------------------------------------------------
96 #-----------------------------------------------------------------------------
97 # Mixin tools for apps that use Sessions
97 # Mixin tools for apps that use Sessions
98 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
99
99
100 session_aliases = dict(
100 session_aliases = dict(
101 ident = 'Session.session',
101 ident = 'Session.session',
102 user = 'Session.username',
102 user = 'Session.username',
103 keyfile = 'Session.keyfile',
103 keyfile = 'Session.keyfile',
104 )
104 )
105
105
106 session_flags = {
106 session_flags = {
107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 'keyfile' : '' }},
108 'keyfile' : '' }},
109 """Use HMAC digests for authentication of messages.
109 """Use HMAC digests for authentication of messages.
110 Setting this flag will generate a new UUID to use as the HMAC key.
110 Setting this flag will generate a new UUID to use as the HMAC key.
111 """),
111 """),
112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 """Don't authenticate messages."""),
113 """Don't authenticate messages."""),
114 }
114 }
115
115
116 def default_secure(cfg):
116 def default_secure(cfg):
117 """Set the default behavior for a config environment to be secure.
117 """Set the default behavior for a config environment to be secure.
118
118
119 If Session.key/keyfile have not been set, set Session.key to
119 If Session.key/keyfile have not been set, set Session.key to
120 a new random UUID.
120 a new random UUID.
121 """
121 """
122
122
123 if 'Session' in cfg:
123 if 'Session' in cfg:
124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 return
125 return
126 # key/keyfile not specified, generate new UUID:
126 # key/keyfile not specified, generate new UUID:
127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128
128
129
129
130 #-----------------------------------------------------------------------------
130 #-----------------------------------------------------------------------------
131 # Classes
131 # Classes
132 #-----------------------------------------------------------------------------
132 #-----------------------------------------------------------------------------
133
133
134 class SessionFactory(LoggingConfigurable):
134 class SessionFactory(LoggingConfigurable):
135 """The Base class for configurables that have a Session, Context, logger,
135 """The Base class for configurables that have a Session, Context, logger,
136 and IOLoop.
136 and IOLoop.
137 """
137 """
138
138
139 logname = Unicode('')
139 logname = Unicode('')
140 def _logname_changed(self, name, old, new):
140 def _logname_changed(self, name, old, new):
141 self.log = logging.getLogger(new)
141 self.log = logging.getLogger(new)
142
142
143 # not configurable:
143 # not configurable:
144 context = Instance('zmq.Context')
144 context = Instance('zmq.Context')
145 def _context_default(self):
145 def _context_default(self):
146 return zmq.Context.instance()
146 return zmq.Context.instance()
147
147
148 session = Instance('IPython.kernel.zmq.session.Session')
148 session = Instance('IPython.kernel.zmq.session.Session')
149
149
150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 def _loop_default(self):
151 def _loop_default(self):
152 return IOLoop.instance()
152 return IOLoop.instance()
153
153
154 def __init__(self, **kwargs):
154 def __init__(self, **kwargs):
155 super(SessionFactory, self).__init__(**kwargs)
155 super(SessionFactory, self).__init__(**kwargs)
156
156
157 if self.session is None:
157 if self.session is None:
158 # construct the session
158 # construct the session
159 self.session = Session(**kwargs)
159 self.session = Session(**kwargs)
160
160
161
161
162 class Message(object):
162 class Message(object):
163 """A simple message object that maps dict keys to attributes.
163 """A simple message object that maps dict keys to attributes.
164
164
165 A Message can be created from a dict and a dict from a Message instance
165 A Message can be created from a dict and a dict from a Message instance
166 simply by calling dict(msg_obj)."""
166 simply by calling dict(msg_obj)."""
167
167
168 def __init__(self, msg_dict):
168 def __init__(self, msg_dict):
169 dct = self.__dict__
169 dct = self.__dict__
170 for k, v in iteritems(dict(msg_dict)):
170 for k, v in iteritems(dict(msg_dict)):
171 if isinstance(v, dict):
171 if isinstance(v, dict):
172 v = Message(v)
172 v = Message(v)
173 dct[k] = v
173 dct[k] = v
174
174
175 # Having this iterator lets dict(msg_obj) work out of the box.
175 # Having this iterator lets dict(msg_obj) work out of the box.
176 def __iter__(self):
176 def __iter__(self):
177 return iter(iteritems(self.__dict__))
177 return iter(iteritems(self.__dict__))
178
178
179 def __repr__(self):
179 def __repr__(self):
180 return repr(self.__dict__)
180 return repr(self.__dict__)
181
181
182 def __str__(self):
182 def __str__(self):
183 return pprint.pformat(self.__dict__)
183 return pprint.pformat(self.__dict__)
184
184
185 def __contains__(self, k):
185 def __contains__(self, k):
186 return k in self.__dict__
186 return k in self.__dict__
187
187
188 def __getitem__(self, k):
188 def __getitem__(self, k):
189 return self.__dict__[k]
189 return self.__dict__[k]
190
190
191
191
192 def msg_header(msg_id, msg_type, username, session):
192 def msg_header(msg_id, msg_type, username, session):
193 date = datetime.now()
193 date = datetime.now()
194 return locals()
194 return locals()
195
195
196 def extract_header(msg_or_header):
196 def extract_header(msg_or_header):
197 """Given a message or header, return the header."""
197 """Given a message or header, return the header."""
198 if not msg_or_header:
198 if not msg_or_header:
199 return {}
199 return {}
200 try:
200 try:
201 # See if msg_or_header is the entire message.
201 # See if msg_or_header is the entire message.
202 h = msg_or_header['header']
202 h = msg_or_header['header']
203 except KeyError:
203 except KeyError:
204 try:
204 try:
205 # See if msg_or_header is just the header
205 # See if msg_or_header is just the header
206 h = msg_or_header['msg_id']
206 h = msg_or_header['msg_id']
207 except KeyError:
207 except KeyError:
208 raise
208 raise
209 else:
209 else:
210 h = msg_or_header
210 h = msg_or_header
211 if not isinstance(h, dict):
211 if not isinstance(h, dict):
212 h = dict(h)
212 h = dict(h)
213 return h
213 return h
214
214
215 class Session(Configurable):
215 class Session(Configurable):
216 """Object for handling serialization and sending of messages.
216 """Object for handling serialization and sending of messages.
217
217
218 The Session object handles building messages and sending them
218 The Session object handles building messages and sending them
219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 other over the network via Session objects, and only need to work with the
220 other over the network via Session objects, and only need to work with the
221 dict-based IPython message spec. The Session will handle
221 dict-based IPython message spec. The Session will handle
222 serialization/deserialization, security, and metadata.
222 serialization/deserialization, security, and metadata.
223
223
224 Sessions support configurable serialiization via packer/unpacker traits,
224 Sessions support configurable serialiization via packer/unpacker traits,
225 and signing with HMAC digests via the key/keyfile traits.
225 and signing with HMAC digests via the key/keyfile traits.
226
226
227 Parameters
227 Parameters
228 ----------
228 ----------
229
229
230 debug : bool
230 debug : bool
231 whether to trigger extra debugging statements
231 whether to trigger extra debugging statements
232 packer/unpacker : str : 'json', 'pickle' or import_string
232 packer/unpacker : str : 'json', 'pickle' or import_string
233 importstrings for methods to serialize message parts. If just
233 importstrings for methods to serialize message parts. If just
234 'json' or 'pickle', predefined JSON and pickle packers will be used.
234 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 Otherwise, the entire importstring must be used.
235 Otherwise, the entire importstring must be used.
236
236
237 The functions must accept at least valid JSON input, and output *bytes*.
237 The functions must accept at least valid JSON input, and output *bytes*.
238
238
239 For example, to use msgpack:
239 For example, to use msgpack:
240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 pack/unpack : callables
241 pack/unpack : callables
242 You can also set the pack/unpack callables for serialization directly.
242 You can also set the pack/unpack callables for serialization directly.
243 session : bytes
243 session : bytes
244 the ID of this Session object. The default is to generate a new UUID.
244 the ID of this Session object. The default is to generate a new UUID.
245 username : unicode
245 username : unicode
246 username added to message headers. The default is to ask the OS.
246 username added to message headers. The default is to ask the OS.
247 key : bytes
247 key : bytes
248 The key used to initialize an HMAC signature. If unset, messages
248 The key used to initialize an HMAC signature. If unset, messages
249 will not be signed or checked.
249 will not be signed or checked.
250 keyfile : filepath
250 keyfile : filepath
251 The file containing a key. If this is set, `key` will be initialized
251 The file containing a key. If this is set, `key` will be initialized
252 to the contents of the file.
252 to the contents of the file.
253
253
254 """
254 """
255
255
256 debug=Bool(False, config=True, help="""Debug output in the Session""")
256 debug=Bool(False, config=True, help="""Debug output in the Session""")
257
257
258 packer = DottedObjectName('json',config=True,
258 packer = DottedObjectName('json',config=True,
259 help="""The name of the packer for serializing messages.
259 help="""The name of the packer for serializing messages.
260 Should be one of 'json', 'pickle', or an import name
260 Should be one of 'json', 'pickle', or an import name
261 for a custom callable serializer.""")
261 for a custom callable serializer.""")
262 def _packer_changed(self, name, old, new):
262 def _packer_changed(self, name, old, new):
263 if new.lower() == 'json':
263 if new.lower() == 'json':
264 self.pack = json_packer
264 self.pack = json_packer
265 self.unpack = json_unpacker
265 self.unpack = json_unpacker
266 self.unpacker = new
266 self.unpacker = new
267 elif new.lower() == 'pickle':
267 elif new.lower() == 'pickle':
268 self.pack = pickle_packer
268 self.pack = pickle_packer
269 self.unpack = pickle_unpacker
269 self.unpack = pickle_unpacker
270 self.unpacker = new
270 self.unpacker = new
271 else:
271 else:
272 self.pack = import_item(str(new))
272 self.pack = import_item(str(new))
273
273
274 unpacker = DottedObjectName('json', config=True,
274 unpacker = DottedObjectName('json', config=True,
275 help="""The name of the unpacker for unserializing messages.
275 help="""The name of the unpacker for unserializing messages.
276 Only used with custom functions for `packer`.""")
276 Only used with custom functions for `packer`.""")
277 def _unpacker_changed(self, name, old, new):
277 def _unpacker_changed(self, name, old, new):
278 if new.lower() == 'json':
278 if new.lower() == 'json':
279 self.pack = json_packer
279 self.pack = json_packer
280 self.unpack = json_unpacker
280 self.unpack = json_unpacker
281 self.packer = new
281 self.packer = new
282 elif new.lower() == 'pickle':
282 elif new.lower() == 'pickle':
283 self.pack = pickle_packer
283 self.pack = pickle_packer
284 self.unpack = pickle_unpacker
284 self.unpack = pickle_unpacker
285 self.packer = new
285 self.packer = new
286 else:
286 else:
287 self.unpack = import_item(str(new))
287 self.unpack = import_item(str(new))
288
288
289 session = CUnicode(u'', config=True,
289 session = CUnicode(u'', config=True,
290 help="""The UUID identifying this session.""")
290 help="""The UUID identifying this session.""")
291 def _session_default(self):
291 def _session_default(self):
292 u = unicode_type(uuid.uuid4())
292 u = unicode_type(uuid.uuid4())
293 self.bsession = u.encode('ascii')
293 self.bsession = u.encode('ascii')
294 return u
294 return u
295
295
296 def _session_changed(self, name, old, new):
296 def _session_changed(self, name, old, new):
297 self.bsession = self.session.encode('ascii')
297 self.bsession = self.session.encode('ascii')
298
298
299 # bsession is the session as bytes
299 # bsession is the session as bytes
300 bsession = CBytes(b'')
300 bsession = CBytes(b'')
301
301
302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 help="""Username for the Session. Default is your system username.""",
303 help="""Username for the Session. Default is your system username.""",
304 config=True)
304 config=True)
305
305
306 metadata = Dict({}, config=True,
306 metadata = Dict({}, config=True,
307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308
308
309 # message signature related traits:
309 # message signature related traits:
310
310
311 key = CBytes(b'', config=True,
311 key = CBytes(b'', config=True,
312 help="""execution key, for extra authentication.""")
312 help="""execution key, for extra authentication.""")
313 def _key_changed(self, name, old, new):
313 def _key_changed(self, name, old, new):
314 if new:
314 if new:
315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 else:
316 else:
317 self.auth = None
317 self.auth = None
318
318
319 signature_scheme = Unicode('hmac-sha256', config=True,
319 signature_scheme = Unicode('hmac-sha256', config=True,
320 help="""The digest scheme used to construct the message signatures.
320 help="""The digest scheme used to construct the message signatures.
321 Must have the form 'hmac-HASH'.""")
321 Must have the form 'hmac-HASH'.""")
322 def _signature_scheme_changed(self, name, old, new):
322 def _signature_scheme_changed(self, name, old, new):
323 if not new.startswith('hmac-'):
323 if not new.startswith('hmac-'):
324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 hash_name = new.split('-', 1)[1]
325 hash_name = new.split('-', 1)[1]
326 try:
326 try:
327 self.digest_mod = getattr(hashlib, hash_name)
327 self.digest_mod = getattr(hashlib, hash_name)
328 except AttributeError:
328 except AttributeError:
329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330
330
331 digest_mod = Any()
331 digest_mod = Any()
332 def _digest_mod_default(self):
332 def _digest_mod_default(self):
333 return hashlib.sha256
333 return hashlib.sha256
334
334
335 auth = Instance(hmac.HMAC)
335 auth = Instance(hmac.HMAC)
336
336
337 digest_history = Set()
337 digest_history = Set()
338 digest_history_size = Integer(2**16, config=True,
338 digest_history_size = Integer(2**16, config=True,
339 help="""The maximum number of digests to remember.
339 help="""The maximum number of digests to remember.
340
340
341 The digest history will be culled when it exceeds this value.
341 The digest history will be culled when it exceeds this value.
342 """
342 """
343 )
343 )
344
344
345 keyfile = Unicode('', config=True,
345 keyfile = Unicode('', config=True,
346 help="""path to file containing execution key.""")
346 help="""path to file containing execution key.""")
347 def _keyfile_changed(self, name, old, new):
347 def _keyfile_changed(self, name, old, new):
348 with open(new, 'rb') as f:
348 with open(new, 'rb') as f:
349 self.key = f.read().strip()
349 self.key = f.read().strip()
350
350
351 # for protecting against sends from forks
351 # for protecting against sends from forks
352 pid = Integer()
352 pid = Integer()
353
353
354 # serialization traits:
354 # serialization traits:
355
355
356 pack = Any(default_packer) # the actual packer function
356 pack = Any(default_packer) # the actual packer function
357 def _pack_changed(self, name, old, new):
357 def _pack_changed(self, name, old, new):
358 if not callable(new):
358 if not callable(new):
359 raise TypeError("packer must be callable, not %s"%type(new))
359 raise TypeError("packer must be callable, not %s"%type(new))
360
360
361 unpack = Any(default_unpacker) # the actual packer function
361 unpack = Any(default_unpacker) # the actual packer function
362 def _unpack_changed(self, name, old, new):
362 def _unpack_changed(self, name, old, new):
363 # unpacker is not checked - it is assumed to be
363 # unpacker is not checked - it is assumed to be
364 if not callable(new):
364 if not callable(new):
365 raise TypeError("unpacker must be callable, not %s"%type(new))
365 raise TypeError("unpacker must be callable, not %s"%type(new))
366
366
367 # thresholds:
367 # thresholds:
368 copy_threshold = Integer(2**16, config=True,
368 copy_threshold = Integer(2**16, config=True,
369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 buffer_threshold = Integer(MAX_BYTES, config=True,
370 buffer_threshold = Integer(MAX_BYTES, config=True,
371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 item_threshold = Integer(MAX_ITEMS, config=True,
372 item_threshold = Integer(MAX_ITEMS, config=True,
373 help="""The maximum number of items for a container to be introspected for custom serialization.
373 help="""The maximum number of items for a container to be introspected for custom serialization.
374 Containers larger than this are pickled outright.
374 Containers larger than this are pickled outright.
375 """
375 """
376 )
376 )
377
377
378
378
379 def __init__(self, **kwargs):
379 def __init__(self, **kwargs):
380 """create a Session object
380 """create a Session object
381
381
382 Parameters
382 Parameters
383 ----------
383 ----------
384
384
385 debug : bool
385 debug : bool
386 whether to trigger extra debugging statements
386 whether to trigger extra debugging statements
387 packer/unpacker : str : 'json', 'pickle' or import_string
387 packer/unpacker : str : 'json', 'pickle' or import_string
388 importstrings for methods to serialize message parts. If just
388 importstrings for methods to serialize message parts. If just
389 'json' or 'pickle', predefined JSON and pickle packers will be used.
389 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 Otherwise, the entire importstring must be used.
390 Otherwise, the entire importstring must be used.
391
391
392 The functions must accept at least valid JSON input, and output
392 The functions must accept at least valid JSON input, and output
393 *bytes*.
393 *bytes*.
394
394
395 For example, to use msgpack:
395 For example, to use msgpack:
396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 pack/unpack : callables
397 pack/unpack : callables
398 You can also set the pack/unpack callables for serialization
398 You can also set the pack/unpack callables for serialization
399 directly.
399 directly.
400 session : unicode (must be ascii)
400 session : unicode (must be ascii)
401 the ID of this Session object. The default is to generate a new
401 the ID of this Session object. The default is to generate a new
402 UUID.
402 UUID.
403 bsession : bytes
403 bsession : bytes
404 The session as bytes
404 The session as bytes
405 username : unicode
405 username : unicode
406 username added to message headers. The default is to ask the OS.
406 username added to message headers. The default is to ask the OS.
407 key : bytes
407 key : bytes
408 The key used to initialize an HMAC signature. If unset, messages
408 The key used to initialize an HMAC signature. If unset, messages
409 will not be signed or checked.
409 will not be signed or checked.
410 signature_scheme : str
410 signature_scheme : str
411 The message digest scheme. Currently must be of the form 'hmac-HASH',
411 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 where 'HASH' is a hashing function available in Python's hashlib.
412 where 'HASH' is a hashing function available in Python's hashlib.
413 The default is 'hmac-sha256'.
413 The default is 'hmac-sha256'.
414 This is ignored if 'key' is empty.
414 This is ignored if 'key' is empty.
415 keyfile : filepath
415 keyfile : filepath
416 The file containing a key. If this is set, `key` will be
416 The file containing a key. If this is set, `key` will be
417 initialized to the contents of the file.
417 initialized to the contents of the file.
418 """
418 """
419 super(Session, self).__init__(**kwargs)
419 super(Session, self).__init__(**kwargs)
420 self._check_packers()
420 self._check_packers()
421 self.none = self.pack({})
421 self.none = self.pack({})
422 # ensure self._session_default() if necessary, so bsession is defined:
422 # ensure self._session_default() if necessary, so bsession is defined:
423 self.session
423 self.session
424 self.pid = os.getpid()
424 self.pid = os.getpid()
425
425
426 @property
426 @property
427 def msg_id(self):
427 def msg_id(self):
428 """always return new uuid"""
428 """always return new uuid"""
429 return str(uuid.uuid4())
429 return str(uuid.uuid4())
430
430
431 def _check_packers(self):
431 def _check_packers(self):
432 """check packers for datetime support."""
432 """check packers for datetime support."""
433 pack = self.pack
433 pack = self.pack
434 unpack = self.unpack
434 unpack = self.unpack
435
435
436 # check simple serialization
436 # check simple serialization
437 msg = dict(a=[1,'hi'])
437 msg = dict(a=[1,'hi'])
438 try:
438 try:
439 packed = pack(msg)
439 packed = pack(msg)
440 except Exception as e:
440 except Exception as e:
441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 if self.packer == 'json':
442 if self.packer == 'json':
443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 else:
444 else:
445 jsonmsg = ""
445 jsonmsg = ""
446 raise ValueError(
446 raise ValueError(
447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 )
448 )
449
449
450 # ensure packed message is bytes
450 # ensure packed message is bytes
451 if not isinstance(packed, bytes):
451 if not isinstance(packed, bytes):
452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453
453
454 # check that unpack is pack's inverse
454 # check that unpack is pack's inverse
455 try:
455 try:
456 unpacked = unpack(packed)
456 unpacked = unpack(packed)
457 assert unpacked == msg
457 assert unpacked == msg
458 except Exception as e:
458 except Exception as e:
459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 if self.packer == 'json':
460 if self.packer == 'json':
461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 else:
462 else:
463 jsonmsg = ""
463 jsonmsg = ""
464 raise ValueError(
464 raise ValueError(
465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 )
466 )
467
467
468 # check datetime support
468 # check datetime support
469 msg = dict(t=datetime.now())
469 msg = dict(t=datetime.now())
470 try:
470 try:
471 unpacked = unpack(pack(msg))
471 unpacked = unpack(pack(msg))
472 if isinstance(unpacked['t'], datetime):
472 if isinstance(unpacked['t'], datetime):
473 raise ValueError("Shouldn't deserialize to datetime")
473 raise ValueError("Shouldn't deserialize to datetime")
474 except Exception:
474 except Exception:
475 self.pack = lambda o: pack(squash_dates(o))
475 self.pack = lambda o: pack(squash_dates(o))
476 self.unpack = lambda s: unpack(s)
476 self.unpack = lambda s: unpack(s)
477
477
478 def msg_header(self, msg_type):
478 def msg_header(self, msg_type):
479 return msg_header(self.msg_id, msg_type, self.username, self.session)
479 return msg_header(self.msg_id, msg_type, self.username, self.session)
480
480
481 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
481 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 """Return the nested message dict.
482 """Return the nested message dict.
483
483
484 This format is different from what is sent over the wire. The
484 This format is different from what is sent over the wire. The
485 serialize/unserialize methods converts this nested message dict to the wire
485 serialize/unserialize methods converts this nested message dict to the wire
486 format, which is a list of message parts.
486 format, which is a list of message parts.
487 """
487 """
488 msg = {}
488 msg = {}
489 header = self.msg_header(msg_type) if header is None else header
489 header = self.msg_header(msg_type) if header is None else header
490 msg['header'] = header
490 msg['header'] = header
491 msg['msg_id'] = header['msg_id']
491 msg['msg_id'] = header['msg_id']
492 msg['msg_type'] = header['msg_type']
492 msg['msg_type'] = header['msg_type']
493 msg['parent_header'] = {} if parent is None else extract_header(parent)
493 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 msg['content'] = {} if content is None else content
494 msg['content'] = {} if content is None else content
495 msg['metadata'] = self.metadata.copy()
495 msg['metadata'] = self.metadata.copy()
496 if metadata is not None:
496 if metadata is not None:
497 msg['metadata'].update(metadata)
497 msg['metadata'].update(metadata)
498 return msg
498 return msg
499
499
500 def sign(self, msg_list):
500 def sign(self, msg_list):
501 """Sign a message with HMAC digest. If no auth, return b''.
501 """Sign a message with HMAC digest. If no auth, return b''.
502
502
503 Parameters
503 Parameters
504 ----------
504 ----------
505 msg_list : list
505 msg_list : list
506 The [p_header,p_parent,p_content] part of the message list.
506 The [p_header,p_parent,p_content] part of the message list.
507 """
507 """
508 if self.auth is None:
508 if self.auth is None:
509 return b''
509 return b''
510 h = self.auth.copy()
510 h = self.auth.copy()
511 for m in msg_list:
511 for m in msg_list:
512 h.update(m)
512 h.update(m)
513 return str_to_bytes(h.hexdigest())
513 return str_to_bytes(h.hexdigest())
514
514
515 def serialize(self, msg, ident=None):
515 def serialize(self, msg, ident=None):
516 """Serialize the message components to bytes.
516 """Serialize the message components to bytes.
517
517
518 This is roughly the inverse of unserialize. The serialize/unserialize
518 This is roughly the inverse of unserialize. The serialize/unserialize
519 methods work with full message lists, whereas pack/unpack work with
519 methods work with full message lists, whereas pack/unpack work with
520 the individual message parts in the message list.
520 the individual message parts in the message list.
521
521
522 Parameters
522 Parameters
523 ----------
523 ----------
524 msg : dict or Message
524 msg : dict or Message
525 The nexted message dict as returned by the self.msg method.
525 The nexted message dict as returned by the self.msg method.
526
526
527 Returns
527 Returns
528 -------
528 -------
529 msg_list : list
529 msg_list : list
530 The list of bytes objects to be sent with the format::
530 The list of bytes objects to be sent with the format::
531
531
532 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
532 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 p_metadata, p_content, buffer1, buffer2, ...]
533 p_metadata, p_content, buffer1, buffer2, ...]
534
534
535 In this list, the ``p_*`` entities are the packed or serialized
535 In this list, the ``p_*`` entities are the packed or serialized
536 versions, so if JSON is used, these are utf8 encoded JSON strings.
536 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 """
537 """
538 content = msg.get('content', {})
538 content = msg.get('content', {})
539 if content is None:
539 if content is None:
540 content = self.none
540 content = self.none
541 elif isinstance(content, dict):
541 elif isinstance(content, dict):
542 content = self.pack(content)
542 content = self.pack(content)
543 elif isinstance(content, bytes):
543 elif isinstance(content, bytes):
544 # content is already packed, as in a relayed message
544 # content is already packed, as in a relayed message
545 pass
545 pass
546 elif isinstance(content, unicode_type):
546 elif isinstance(content, unicode_type):
547 # should be bytes, but JSON often spits out unicode
547 # should be bytes, but JSON often spits out unicode
548 content = content.encode('utf8')
548 content = content.encode('utf8')
549 else:
549 else:
550 raise TypeError("Content incorrect type: %s"%type(content))
550 raise TypeError("Content incorrect type: %s"%type(content))
551
551
552 real_message = [self.pack(msg['header']),
552 real_message = [self.pack(msg['header']),
553 self.pack(msg['parent_header']),
553 self.pack(msg['parent_header']),
554 self.pack(msg['metadata']),
554 self.pack(msg['metadata']),
555 content,
555 content,
556 ]
556 ]
557
557
558 to_send = []
558 to_send = []
559
559
560 if isinstance(ident, list):
560 if isinstance(ident, list):
561 # accept list of idents
561 # accept list of idents
562 to_send.extend(ident)
562 to_send.extend(ident)
563 elif ident is not None:
563 elif ident is not None:
564 to_send.append(ident)
564 to_send.append(ident)
565 to_send.append(DELIM)
565 to_send.append(DELIM)
566
566
567 signature = self.sign(real_message)
567 signature = self.sign(real_message)
568 to_send.append(signature)
568 to_send.append(signature)
569
569
570 to_send.extend(real_message)
570 to_send.extend(real_message)
571
571
572 return to_send
572 return to_send
573
573
574 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
574 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 buffers=None, track=False, header=None, metadata=None):
575 buffers=None, track=False, header=None, metadata=None):
576 """Build and send a message via stream or socket.
576 """Build and send a message via stream or socket.
577
577
578 The message format used by this function internally is as follows:
578 The message format used by this function internally is as follows:
579
579
580 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
580 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
581 buffer1,buffer2,...]
581 buffer1,buffer2,...]
582
582
583 The serialize/unserialize methods convert the nested message dict into this
583 The serialize/unserialize methods convert the nested message dict into this
584 format.
584 format.
585
585
586 Parameters
586 Parameters
587 ----------
587 ----------
588
588
589 stream : zmq.Socket or ZMQStream
589 stream : zmq.Socket or ZMQStream
590 The socket-like object used to send the data.
590 The socket-like object used to send the data.
591 msg_or_type : str or Message/dict
591 msg_or_type : str or Message/dict
592 Normally, msg_or_type will be a msg_type unless a message is being
592 Normally, msg_or_type will be a msg_type unless a message is being
593 sent more than once. If a header is supplied, this can be set to
593 sent more than once. If a header is supplied, this can be set to
594 None and the msg_type will be pulled from the header.
594 None and the msg_type will be pulled from the header.
595
595
596 content : dict or None
596 content : dict or None
597 The content of the message (ignored if msg_or_type is a message).
597 The content of the message (ignored if msg_or_type is a message).
598 header : dict or None
598 header : dict or None
599 The header dict for the message (ignored if msg_to_type is a message).
599 The header dict for the message (ignored if msg_to_type is a message).
600 parent : Message or dict or None
600 parent : Message or dict or None
601 The parent or parent header describing the parent of this message
601 The parent or parent header describing the parent of this message
602 (ignored if msg_or_type is a message).
602 (ignored if msg_or_type is a message).
603 ident : bytes or list of bytes
603 ident : bytes or list of bytes
604 The zmq.IDENTITY routing path.
604 The zmq.IDENTITY routing path.
605 metadata : dict or None
605 metadata : dict or None
606 The metadata describing the message
606 The metadata describing the message
607 buffers : list or None
607 buffers : list or None
608 The already-serialized buffers to be appended to the message.
608 The already-serialized buffers to be appended to the message.
609 track : bool
609 track : bool
610 Whether to track. Only for use with Sockets, because ZMQStream
610 Whether to track. Only for use with Sockets, because ZMQStream
611 objects cannot track messages.
611 objects cannot track messages.
612
612
613
613
614 Returns
614 Returns
615 -------
615 -------
616 msg : dict
616 msg : dict
617 The constructed message.
617 The constructed message.
618 """
618 """
619 if not isinstance(stream, zmq.Socket):
619 if not isinstance(stream, zmq.Socket):
620 # ZMQStreams and dummy sockets do not support tracking.
620 # ZMQStreams and dummy sockets do not support tracking.
621 track = False
621 track = False
622
622
623 if isinstance(msg_or_type, (Message, dict)):
623 if isinstance(msg_or_type, (Message, dict)):
624 # We got a Message or message dict, not a msg_type so don't
624 # We got a Message or message dict, not a msg_type so don't
625 # build a new Message.
625 # build a new Message.
626 msg = msg_or_type
626 msg = msg_or_type
627 else:
627 else:
628 msg = self.msg(msg_or_type, content=content, parent=parent,
628 msg = self.msg(msg_or_type, content=content, parent=parent,
629 header=header, metadata=metadata)
629 header=header, metadata=metadata)
630 if not os.getpid() == self.pid:
630 if not os.getpid() == self.pid:
631 io.rprint("WARNING: attempted to send message from fork")
631 io.rprint("WARNING: attempted to send message from fork")
632 io.rprint(msg)
632 io.rprint(msg)
633 return
633 return
634 buffers = [] if buffers is None else buffers
634 buffers = [] if buffers is None else buffers
635 to_send = self.serialize(msg, ident)
635 to_send = self.serialize(msg, ident)
636 to_send.extend(buffers)
636 to_send.extend(buffers)
637 longest = max([ len(s) for s in to_send ])
637 longest = max([ len(s) for s in to_send ])
638 copy = (longest < self.copy_threshold)
638 copy = (longest < self.copy_threshold)
639
639
640 if buffers and track and not copy:
640 if buffers and track and not copy:
641 # only really track when we are doing zero-copy buffers
641 # only really track when we are doing zero-copy buffers
642 tracker = stream.send_multipart(to_send, copy=False, track=True)
642 tracker = stream.send_multipart(to_send, copy=False, track=True)
643 else:
643 else:
644 # use dummy tracker, which will be done immediately
644 # use dummy tracker, which will be done immediately
645 tracker = DONE
645 tracker = DONE
646 stream.send_multipart(to_send, copy=copy)
646 stream.send_multipart(to_send, copy=copy)
647
647
648 if self.debug:
648 if self.debug:
649 pprint.pprint(msg)
649 pprint.pprint(msg)
650 pprint.pprint(to_send)
650 pprint.pprint(to_send)
651 pprint.pprint(buffers)
651 pprint.pprint(buffers)
652
652
653 msg['tracker'] = tracker
653 msg['tracker'] = tracker
654
654
655 return msg
655 return msg
656
656
657 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
657 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
658 """Send a raw message via ident path.
658 """Send a raw message via ident path.
659
659
660 This method is used to send a already serialized message.
660 This method is used to send a already serialized message.
661
661
662 Parameters
662 Parameters
663 ----------
663 ----------
664 stream : ZMQStream or Socket
664 stream : ZMQStream or Socket
665 The ZMQ stream or socket to use for sending the message.
665 The ZMQ stream or socket to use for sending the message.
666 msg_list : list
666 msg_list : list
667 The serialized list of messages to send. This only includes the
667 The serialized list of messages to send. This only includes the
668 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
668 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
669 the message.
669 the message.
670 ident : ident or list
670 ident : ident or list
671 A single ident or a list of idents to use in sending.
671 A single ident or a list of idents to use in sending.
672 """
672 """
673 to_send = []
673 to_send = []
674 if isinstance(ident, bytes):
674 if isinstance(ident, bytes):
675 ident = [ident]
675 ident = [ident]
676 if ident is not None:
676 if ident is not None:
677 to_send.extend(ident)
677 to_send.extend(ident)
678
678
679 to_send.append(DELIM)
679 to_send.append(DELIM)
680 to_send.append(self.sign(msg_list))
680 to_send.append(self.sign(msg_list))
681 to_send.extend(msg_list)
681 to_send.extend(msg_list)
682 stream.send_multipart(msg_list, flags, copy=copy)
682 stream.send_multipart(to_send, flags, copy=copy)
683
683
684 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
684 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
685 """Receive and unpack a message.
685 """Receive and unpack a message.
686
686
687 Parameters
687 Parameters
688 ----------
688 ----------
689 socket : ZMQStream or Socket
689 socket : ZMQStream or Socket
690 The socket or stream to use in receiving.
690 The socket or stream to use in receiving.
691
691
692 Returns
692 Returns
693 -------
693 -------
694 [idents], msg
694 [idents], msg
695 [idents] is a list of idents and msg is a nested message dict of
695 [idents] is a list of idents and msg is a nested message dict of
696 same format as self.msg returns.
696 same format as self.msg returns.
697 """
697 """
698 if isinstance(socket, ZMQStream):
698 if isinstance(socket, ZMQStream):
699 socket = socket.socket
699 socket = socket.socket
700 try:
700 try:
701 msg_list = socket.recv_multipart(mode, copy=copy)
701 msg_list = socket.recv_multipart(mode, copy=copy)
702 except zmq.ZMQError as e:
702 except zmq.ZMQError as e:
703 if e.errno == zmq.EAGAIN:
703 if e.errno == zmq.EAGAIN:
704 # We can convert EAGAIN to None as we know in this case
704 # We can convert EAGAIN to None as we know in this case
705 # recv_multipart won't return None.
705 # recv_multipart won't return None.
706 return None,None
706 return None,None
707 else:
707 else:
708 raise
708 raise
709 # split multipart message into identity list and message dict
709 # split multipart message into identity list and message dict
710 # invalid large messages can cause very expensive string comparisons
710 # invalid large messages can cause very expensive string comparisons
711 idents, msg_list = self.feed_identities(msg_list, copy)
711 idents, msg_list = self.feed_identities(msg_list, copy)
712 try:
712 try:
713 return idents, self.unserialize(msg_list, content=content, copy=copy)
713 return idents, self.unserialize(msg_list, content=content, copy=copy)
714 except Exception as e:
714 except Exception as e:
715 # TODO: handle it
715 # TODO: handle it
716 raise e
716 raise e
717
717
718 def feed_identities(self, msg_list, copy=True):
718 def feed_identities(self, msg_list, copy=True):
719 """Split the identities from the rest of the message.
719 """Split the identities from the rest of the message.
720
720
721 Feed until DELIM is reached, then return the prefix as idents and
721 Feed until DELIM is reached, then return the prefix as idents and
722 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
722 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
723 but that would be silly.
723 but that would be silly.
724
724
725 Parameters
725 Parameters
726 ----------
726 ----------
727 msg_list : a list of Message or bytes objects
727 msg_list : a list of Message or bytes objects
728 The message to be split.
728 The message to be split.
729 copy : bool
729 copy : bool
730 flag determining whether the arguments are bytes or Messages
730 flag determining whether the arguments are bytes or Messages
731
731
732 Returns
732 Returns
733 -------
733 -------
734 (idents, msg_list) : two lists
734 (idents, msg_list) : two lists
735 idents will always be a list of bytes, each of which is a ZMQ
735 idents will always be a list of bytes, each of which is a ZMQ
736 identity. msg_list will be a list of bytes or zmq.Messages of the
736 identity. msg_list will be a list of bytes or zmq.Messages of the
737 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
737 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
738 should be unpackable/unserializable via self.unserialize at this
738 should be unpackable/unserializable via self.unserialize at this
739 point.
739 point.
740 """
740 """
741 if copy:
741 if copy:
742 idx = msg_list.index(DELIM)
742 idx = msg_list.index(DELIM)
743 return msg_list[:idx], msg_list[idx+1:]
743 return msg_list[:idx], msg_list[idx+1:]
744 else:
744 else:
745 failed = True
745 failed = True
746 for idx,m in enumerate(msg_list):
746 for idx,m in enumerate(msg_list):
747 if m.bytes == DELIM:
747 if m.bytes == DELIM:
748 failed = False
748 failed = False
749 break
749 break
750 if failed:
750 if failed:
751 raise ValueError("DELIM not in msg_list")
751 raise ValueError("DELIM not in msg_list")
752 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
752 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
753 return [m.bytes for m in idents], msg_list
753 return [m.bytes for m in idents], msg_list
754
754
755 def _add_digest(self, signature):
755 def _add_digest(self, signature):
756 """add a digest to history to protect against replay attacks"""
756 """add a digest to history to protect against replay attacks"""
757 if self.digest_history_size == 0:
757 if self.digest_history_size == 0:
758 # no history, never add digests
758 # no history, never add digests
759 return
759 return
760
760
761 self.digest_history.add(signature)
761 self.digest_history.add(signature)
762 if len(self.digest_history) > self.digest_history_size:
762 if len(self.digest_history) > self.digest_history_size:
763 # threshold reached, cull 10%
763 # threshold reached, cull 10%
764 self._cull_digest_history()
764 self._cull_digest_history()
765
765
766 def _cull_digest_history(self):
766 def _cull_digest_history(self):
767 """cull the digest history
767 """cull the digest history
768
768
769 Removes a randomly selected 10% of the digest history
769 Removes a randomly selected 10% of the digest history
770 """
770 """
771 current = len(self.digest_history)
771 current = len(self.digest_history)
772 n_to_cull = max(int(current // 10), current - self.digest_history_size)
772 n_to_cull = max(int(current // 10), current - self.digest_history_size)
773 if n_to_cull >= current:
773 if n_to_cull >= current:
774 self.digest_history = set()
774 self.digest_history = set()
775 return
775 return
776 to_cull = random.sample(self.digest_history, n_to_cull)
776 to_cull = random.sample(self.digest_history, n_to_cull)
777 self.digest_history.difference_update(to_cull)
777 self.digest_history.difference_update(to_cull)
778
778
779 def unserialize(self, msg_list, content=True, copy=True):
779 def unserialize(self, msg_list, content=True, copy=True):
780 """Unserialize a msg_list to a nested message dict.
780 """Unserialize a msg_list to a nested message dict.
781
781
782 This is roughly the inverse of serialize. The serialize/unserialize
782 This is roughly the inverse of serialize. The serialize/unserialize
783 methods work with full message lists, whereas pack/unpack work with
783 methods work with full message lists, whereas pack/unpack work with
784 the individual message parts in the message list.
784 the individual message parts in the message list.
785
785
786 Parameters
786 Parameters
787 ----------
787 ----------
788 msg_list : list of bytes or Message objects
788 msg_list : list of bytes or Message objects
789 The list of message parts of the form [HMAC,p_header,p_parent,
789 The list of message parts of the form [HMAC,p_header,p_parent,
790 p_metadata,p_content,buffer1,buffer2,...].
790 p_metadata,p_content,buffer1,buffer2,...].
791 content : bool (True)
791 content : bool (True)
792 Whether to unpack the content dict (True), or leave it packed
792 Whether to unpack the content dict (True), or leave it packed
793 (False).
793 (False).
794 copy : bool (True)
794 copy : bool (True)
795 Whether to return the bytes (True), or the non-copying Message
795 Whether to return the bytes (True), or the non-copying Message
796 object in each place (False).
796 object in each place (False).
797
797
798 Returns
798 Returns
799 -------
799 -------
800 msg : dict
800 msg : dict
801 The nested message dict with top-level keys [header, parent_header,
801 The nested message dict with top-level keys [header, parent_header,
802 content, buffers].
802 content, buffers].
803 """
803 """
804 minlen = 5
804 minlen = 5
805 message = {}
805 message = {}
806 if not copy:
806 if not copy:
807 for i in range(minlen):
807 for i in range(minlen):
808 msg_list[i] = msg_list[i].bytes
808 msg_list[i] = msg_list[i].bytes
809 if self.auth is not None:
809 if self.auth is not None:
810 signature = msg_list[0]
810 signature = msg_list[0]
811 if not signature:
811 if not signature:
812 raise ValueError("Unsigned Message")
812 raise ValueError("Unsigned Message")
813 if signature in self.digest_history:
813 if signature in self.digest_history:
814 raise ValueError("Duplicate Signature: %r" % signature)
814 raise ValueError("Duplicate Signature: %r" % signature)
815 self._add_digest(signature)
815 self._add_digest(signature)
816 check = self.sign(msg_list[1:5])
816 check = self.sign(msg_list[1:5])
817 if not signature == check:
817 if not signature == check:
818 raise ValueError("Invalid Signature: %r" % signature)
818 raise ValueError("Invalid Signature: %r" % signature)
819 if not len(msg_list) >= minlen:
819 if not len(msg_list) >= minlen:
820 raise TypeError("malformed message, must have at least %i elements"%minlen)
820 raise TypeError("malformed message, must have at least %i elements"%minlen)
821 header = self.unpack(msg_list[1])
821 header = self.unpack(msg_list[1])
822 message['header'] = extract_dates(header)
822 message['header'] = extract_dates(header)
823 message['msg_id'] = header['msg_id']
823 message['msg_id'] = header['msg_id']
824 message['msg_type'] = header['msg_type']
824 message['msg_type'] = header['msg_type']
825 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
825 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
826 message['metadata'] = self.unpack(msg_list[3])
826 message['metadata'] = self.unpack(msg_list[3])
827 if content:
827 if content:
828 message['content'] = self.unpack(msg_list[4])
828 message['content'] = self.unpack(msg_list[4])
829 else:
829 else:
830 message['content'] = msg_list[4]
830 message['content'] = msg_list[4]
831
831
832 message['buffers'] = msg_list[5:]
832 message['buffers'] = msg_list[5:]
833 return message
833 return message
834
834
835 def test_msg2obj():
835 def test_msg2obj():
836 am = dict(x=1)
836 am = dict(x=1)
837 ao = Message(am)
837 ao = Message(am)
838 assert ao.x == am['x']
838 assert ao.x == am['x']
839
839
840 am['y'] = dict(z=1)
840 am['y'] = dict(z=1)
841 ao = Message(am)
841 ao = Message(am)
842 assert ao.y.z == am['y']['z']
842 assert ao.y.z == am['y']['z']
843
843
844 k1, k2 = 'y', 'z'
844 k1, k2 = 'y', 'z'
845 assert ao[k1][k2] == am[k1][k2]
845 assert ao[k1][k2] == am[k1][k2]
846
846
847 am2 = dict(ao)
847 am2 = dict(ao)
848 assert am['x'] == am2['x']
848 assert am['x'] == am2['x']
849 assert am['y']['z'] == am2['y']['z']
849 assert am['y']['z'] == am2['y']['z']
850
850
@@ -1,289 +1,313 b''
1 """test building messages with streamsession"""
1 """test building messages with streamsession"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import uuid
15 import uuid
16 from datetime import datetime
16 from datetime import datetime
17
17
18 import zmq
18 import zmq
19
19
20 from zmq.tests import BaseZMQTestCase
20 from zmq.tests import BaseZMQTestCase
21 from zmq.eventloop.zmqstream import ZMQStream
21 from zmq.eventloop.zmqstream import ZMQStream
22
22
23 from IPython.kernel.zmq import session as ss
23 from IPython.kernel.zmq import session as ss
24
24
25 from IPython.testing.decorators import skipif, module_not_available
25 from IPython.testing.decorators import skipif, module_not_available
26 from IPython.utils.py3compat import string_types
26 from IPython.utils.py3compat import string_types
27 from IPython.utils import jsonutil
27 from IPython.utils import jsonutil
28
28
29 def _bad_packer(obj):
29 def _bad_packer(obj):
30 raise TypeError("I don't work")
30 raise TypeError("I don't work")
31
31
32 def _bad_unpacker(bytes):
32 def _bad_unpacker(bytes):
33 raise TypeError("I don't work either")
33 raise TypeError("I don't work either")
34
34
35 class SessionTestCase(BaseZMQTestCase):
35 class SessionTestCase(BaseZMQTestCase):
36
36
37 def setUp(self):
37 def setUp(self):
38 BaseZMQTestCase.setUp(self)
38 BaseZMQTestCase.setUp(self)
39 self.session = ss.Session()
39 self.session = ss.Session()
40
40
41
41
42 class TestSession(SessionTestCase):
42 class TestSession(SessionTestCase):
43
43
44 def test_msg(self):
44 def test_msg(self):
45 """message format"""
45 """message format"""
46 msg = self.session.msg('execute')
46 msg = self.session.msg('execute')
47 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
47 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
48 s = set(msg.keys())
48 s = set(msg.keys())
49 self.assertEqual(s, thekeys)
49 self.assertEqual(s, thekeys)
50 self.assertTrue(isinstance(msg['content'],dict))
50 self.assertTrue(isinstance(msg['content'],dict))
51 self.assertTrue(isinstance(msg['metadata'],dict))
51 self.assertTrue(isinstance(msg['metadata'],dict))
52 self.assertTrue(isinstance(msg['header'],dict))
52 self.assertTrue(isinstance(msg['header'],dict))
53 self.assertTrue(isinstance(msg['parent_header'],dict))
53 self.assertTrue(isinstance(msg['parent_header'],dict))
54 self.assertTrue(isinstance(msg['msg_id'],str))
54 self.assertTrue(isinstance(msg['msg_id'],str))
55 self.assertTrue(isinstance(msg['msg_type'],str))
55 self.assertTrue(isinstance(msg['msg_type'],str))
56 self.assertEqual(msg['header']['msg_type'], 'execute')
56 self.assertEqual(msg['header']['msg_type'], 'execute')
57 self.assertEqual(msg['msg_type'], 'execute')
57 self.assertEqual(msg['msg_type'], 'execute')
58
58
59 def test_serialize(self):
59 def test_serialize(self):
60 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
60 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
61 msg_list = self.session.serialize(msg, ident=b'foo')
61 msg_list = self.session.serialize(msg, ident=b'foo')
62 ident, msg_list = self.session.feed_identities(msg_list)
62 ident, msg_list = self.session.feed_identities(msg_list)
63 new_msg = self.session.unserialize(msg_list)
63 new_msg = self.session.unserialize(msg_list)
64 self.assertEqual(ident[0], b'foo')
64 self.assertEqual(ident[0], b'foo')
65 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
65 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
66 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
66 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
67 self.assertEqual(new_msg['header'],msg['header'])
67 self.assertEqual(new_msg['header'],msg['header'])
68 self.assertEqual(new_msg['content'],msg['content'])
68 self.assertEqual(new_msg['content'],msg['content'])
69 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
69 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
70 self.assertEqual(new_msg['metadata'],msg['metadata'])
70 self.assertEqual(new_msg['metadata'],msg['metadata'])
71 # ensure floats don't come out as Decimal:
71 # ensure floats don't come out as Decimal:
72 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
72 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
73
73
74 def test_send(self):
74 def test_send(self):
75 ctx = zmq.Context.instance()
75 ctx = zmq.Context.instance()
76 A = ctx.socket(zmq.PAIR)
76 A = ctx.socket(zmq.PAIR)
77 B = ctx.socket(zmq.PAIR)
77 B = ctx.socket(zmq.PAIR)
78 A.bind("inproc://test")
78 A.bind("inproc://test")
79 B.connect("inproc://test")
79 B.connect("inproc://test")
80
80
81 msg = self.session.msg('execute', content=dict(a=10))
81 msg = self.session.msg('execute', content=dict(a=10))
82 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
82 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
83
83
84 ident, msg_list = self.session.feed_identities(B.recv_multipart())
84 ident, msg_list = self.session.feed_identities(B.recv_multipart())
85 new_msg = self.session.unserialize(msg_list)
85 new_msg = self.session.unserialize(msg_list)
86 self.assertEqual(ident[0], b'foo')
86 self.assertEqual(ident[0], b'foo')
87 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
87 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
88 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
88 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
89 self.assertEqual(new_msg['header'],msg['header'])
89 self.assertEqual(new_msg['header'],msg['header'])
90 self.assertEqual(new_msg['content'],msg['content'])
90 self.assertEqual(new_msg['content'],msg['content'])
91 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
91 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
92 self.assertEqual(new_msg['metadata'],msg['metadata'])
92 self.assertEqual(new_msg['metadata'],msg['metadata'])
93 self.assertEqual(new_msg['buffers'],[b'bar'])
93 self.assertEqual(new_msg['buffers'],[b'bar'])
94
94
95 content = msg['content']
95 content = msg['content']
96 header = msg['header']
96 header = msg['header']
97 parent = msg['parent_header']
97 parent = msg['parent_header']
98 metadata = msg['metadata']
98 metadata = msg['metadata']
99 msg_type = header['msg_type']
99 msg_type = header['msg_type']
100 self.session.send(A, None, content=content, parent=parent,
100 self.session.send(A, None, content=content, parent=parent,
101 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
101 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
102 ident, msg_list = self.session.feed_identities(B.recv_multipart())
102 ident, msg_list = self.session.feed_identities(B.recv_multipart())
103 new_msg = self.session.unserialize(msg_list)
103 new_msg = self.session.unserialize(msg_list)
104 self.assertEqual(ident[0], b'foo')
104 self.assertEqual(ident[0], b'foo')
105 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
105 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
106 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
106 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
107 self.assertEqual(new_msg['header'],msg['header'])
107 self.assertEqual(new_msg['header'],msg['header'])
108 self.assertEqual(new_msg['content'],msg['content'])
108 self.assertEqual(new_msg['content'],msg['content'])
109 self.assertEqual(new_msg['metadata'],msg['metadata'])
109 self.assertEqual(new_msg['metadata'],msg['metadata'])
110 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
110 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
111 self.assertEqual(new_msg['buffers'],[b'bar'])
111 self.assertEqual(new_msg['buffers'],[b'bar'])
112
112
113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
114 ident, new_msg = self.session.recv(B)
114 ident, new_msg = self.session.recv(B)
115 self.assertEqual(ident[0], b'foo')
115 self.assertEqual(ident[0], b'foo')
116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
118 self.assertEqual(new_msg['header'],msg['header'])
118 self.assertEqual(new_msg['header'],msg['header'])
119 self.assertEqual(new_msg['content'],msg['content'])
119 self.assertEqual(new_msg['content'],msg['content'])
120 self.assertEqual(new_msg['metadata'],msg['metadata'])
120 self.assertEqual(new_msg['metadata'],msg['metadata'])
121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
122 self.assertEqual(new_msg['buffers'],[b'bar'])
122 self.assertEqual(new_msg['buffers'],[b'bar'])
123
123
124 A.close()
124 A.close()
125 B.close()
125 B.close()
126 ctx.term()
126 ctx.term()
127
127
128 def test_args(self):
128 def test_args(self):
129 """initialization arguments for Session"""
129 """initialization arguments for Session"""
130 s = self.session
130 s = self.session
131 self.assertTrue(s.pack is ss.default_packer)
131 self.assertTrue(s.pack is ss.default_packer)
132 self.assertTrue(s.unpack is ss.default_unpacker)
132 self.assertTrue(s.unpack is ss.default_unpacker)
133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
134
134
135 s = ss.Session()
135 s = ss.Session()
136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
137
137
138 self.assertRaises(TypeError, ss.Session, pack='hi')
138 self.assertRaises(TypeError, ss.Session, pack='hi')
139 self.assertRaises(TypeError, ss.Session, unpack='hi')
139 self.assertRaises(TypeError, ss.Session, unpack='hi')
140 u = str(uuid.uuid4())
140 u = str(uuid.uuid4())
141 s = ss.Session(username=u'carrot', session=u)
141 s = ss.Session(username=u'carrot', session=u)
142 self.assertEqual(s.session, u)
142 self.assertEqual(s.session, u)
143 self.assertEqual(s.username, u'carrot')
143 self.assertEqual(s.username, u'carrot')
144
144
145 def test_tracking(self):
145 def test_tracking(self):
146 """test tracking messages"""
146 """test tracking messages"""
147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
148 s = self.session
148 s = self.session
149 s.copy_threshold = 1
149 s.copy_threshold = 1
150 stream = ZMQStream(a)
150 stream = ZMQStream(a)
151 msg = s.send(a, 'hello', track=False)
151 msg = s.send(a, 'hello', track=False)
152 self.assertTrue(msg['tracker'] is ss.DONE)
152 self.assertTrue(msg['tracker'] is ss.DONE)
153 msg = s.send(a, 'hello', track=True)
153 msg = s.send(a, 'hello', track=True)
154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
155 M = zmq.Message(b'hi there', track=True)
155 M = zmq.Message(b'hi there', track=True)
156 msg = s.send(a, 'hello', buffers=[M], track=True)
156 msg = s.send(a, 'hello', buffers=[M], track=True)
157 t = msg['tracker']
157 t = msg['tracker']
158 self.assertTrue(isinstance(t, zmq.MessageTracker))
158 self.assertTrue(isinstance(t, zmq.MessageTracker))
159 self.assertRaises(zmq.NotDone, t.wait, .1)
159 self.assertRaises(zmq.NotDone, t.wait, .1)
160 del M
160 del M
161 t.wait(1) # this will raise
161 t.wait(1) # this will raise
162
162
163
163
164 def test_unique_msg_ids(self):
164 def test_unique_msg_ids(self):
165 """test that messages receive unique ids"""
165 """test that messages receive unique ids"""
166 ids = set()
166 ids = set()
167 for i in range(2**12):
167 for i in range(2**12):
168 h = self.session.msg_header('test')
168 h = self.session.msg_header('test')
169 msg_id = h['msg_id']
169 msg_id = h['msg_id']
170 self.assertTrue(msg_id not in ids)
170 self.assertTrue(msg_id not in ids)
171 ids.add(msg_id)
171 ids.add(msg_id)
172
172
173 def test_feed_identities(self):
173 def test_feed_identities(self):
174 """scrub the front for zmq IDENTITIES"""
174 """scrub the front for zmq IDENTITIES"""
175 theids = "engine client other".split()
175 theids = "engine client other".split()
176 content = dict(code='whoda',stuff=object())
176 content = dict(code='whoda',stuff=object())
177 themsg = self.session.msg('execute',content=content)
177 themsg = self.session.msg('execute',content=content)
178 pmsg = theids
178 pmsg = theids
179
179
180 def test_session_id(self):
180 def test_session_id(self):
181 session = ss.Session()
181 session = ss.Session()
182 # get bs before us
182 # get bs before us
183 bs = session.bsession
183 bs = session.bsession
184 us = session.session
184 us = session.session
185 self.assertEqual(us.encode('ascii'), bs)
185 self.assertEqual(us.encode('ascii'), bs)
186 session = ss.Session()
186 session = ss.Session()
187 # get us before bs
187 # get us before bs
188 us = session.session
188 us = session.session
189 bs = session.bsession
189 bs = session.bsession
190 self.assertEqual(us.encode('ascii'), bs)
190 self.assertEqual(us.encode('ascii'), bs)
191 # change propagates:
191 # change propagates:
192 session.session = 'something else'
192 session.session = 'something else'
193 bs = session.bsession
193 bs = session.bsession
194 us = session.session
194 us = session.session
195 self.assertEqual(us.encode('ascii'), bs)
195 self.assertEqual(us.encode('ascii'), bs)
196 session = ss.Session(session='stuff')
196 session = ss.Session(session='stuff')
197 # get us before bs
197 # get us before bs
198 self.assertEqual(session.bsession, session.session.encode('ascii'))
198 self.assertEqual(session.bsession, session.session.encode('ascii'))
199 self.assertEqual(b'stuff', session.bsession)
199 self.assertEqual(b'stuff', session.bsession)
200
200
201 def test_zero_digest_history(self):
201 def test_zero_digest_history(self):
202 session = ss.Session(digest_history_size=0)
202 session = ss.Session(digest_history_size=0)
203 for i in range(11):
203 for i in range(11):
204 session._add_digest(uuid.uuid4().bytes)
204 session._add_digest(uuid.uuid4().bytes)
205 self.assertEqual(len(session.digest_history), 0)
205 self.assertEqual(len(session.digest_history), 0)
206
206
207 def test_cull_digest_history(self):
207 def test_cull_digest_history(self):
208 session = ss.Session(digest_history_size=100)
208 session = ss.Session(digest_history_size=100)
209 for i in range(100):
209 for i in range(100):
210 session._add_digest(uuid.uuid4().bytes)
210 session._add_digest(uuid.uuid4().bytes)
211 self.assertTrue(len(session.digest_history) == 100)
211 self.assertTrue(len(session.digest_history) == 100)
212 session._add_digest(uuid.uuid4().bytes)
212 session._add_digest(uuid.uuid4().bytes)
213 self.assertTrue(len(session.digest_history) == 91)
213 self.assertTrue(len(session.digest_history) == 91)
214 for i in range(9):
214 for i in range(9):
215 session._add_digest(uuid.uuid4().bytes)
215 session._add_digest(uuid.uuid4().bytes)
216 self.assertTrue(len(session.digest_history) == 100)
216 self.assertTrue(len(session.digest_history) == 100)
217 session._add_digest(uuid.uuid4().bytes)
217 session._add_digest(uuid.uuid4().bytes)
218 self.assertTrue(len(session.digest_history) == 91)
218 self.assertTrue(len(session.digest_history) == 91)
219
219
220 def test_bad_pack(self):
220 def test_bad_pack(self):
221 try:
221 try:
222 session = ss.Session(pack=_bad_packer)
222 session = ss.Session(pack=_bad_packer)
223 except ValueError as e:
223 except ValueError as e:
224 self.assertIn("could not serialize", str(e))
224 self.assertIn("could not serialize", str(e))
225 self.assertIn("don't work", str(e))
225 self.assertIn("don't work", str(e))
226 else:
226 else:
227 self.fail("Should have raised ValueError")
227 self.fail("Should have raised ValueError")
228
228
229 def test_bad_unpack(self):
229 def test_bad_unpack(self):
230 try:
230 try:
231 session = ss.Session(unpack=_bad_unpacker)
231 session = ss.Session(unpack=_bad_unpacker)
232 except ValueError as e:
232 except ValueError as e:
233 self.assertIn("could not handle output", str(e))
233 self.assertIn("could not handle output", str(e))
234 self.assertIn("don't work either", str(e))
234 self.assertIn("don't work either", str(e))
235 else:
235 else:
236 self.fail("Should have raised ValueError")
236 self.fail("Should have raised ValueError")
237
237
238 def test_bad_packer(self):
238 def test_bad_packer(self):
239 try:
239 try:
240 session = ss.Session(packer=__name__ + '._bad_packer')
240 session = ss.Session(packer=__name__ + '._bad_packer')
241 except ValueError as e:
241 except ValueError as e:
242 self.assertIn("could not serialize", str(e))
242 self.assertIn("could not serialize", str(e))
243 self.assertIn("don't work", str(e))
243 self.assertIn("don't work", str(e))
244 else:
244 else:
245 self.fail("Should have raised ValueError")
245 self.fail("Should have raised ValueError")
246
246
247 def test_bad_unpacker(self):
247 def test_bad_unpacker(self):
248 try:
248 try:
249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
250 except ValueError as e:
250 except ValueError as e:
251 self.assertIn("could not handle output", str(e))
251 self.assertIn("could not handle output", str(e))
252 self.assertIn("don't work either", str(e))
252 self.assertIn("don't work either", str(e))
253 else:
253 else:
254 self.fail("Should have raised ValueError")
254 self.fail("Should have raised ValueError")
255
255
256 def test_bad_roundtrip(self):
256 def test_bad_roundtrip(self):
257 with self.assertRaises(ValueError):
257 with self.assertRaises(ValueError):
258 session = ss.Session(unpack=lambda b: 5)
258 session = ss.Session(unpack=lambda b: 5)
259
259
260 def _datetime_test(self, session):
260 def _datetime_test(self, session):
261 content = dict(t=datetime.now())
261 content = dict(t=datetime.now())
262 metadata = dict(t=datetime.now())
262 metadata = dict(t=datetime.now())
263 p = session.msg('msg')
263 p = session.msg('msg')
264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
265 smsg = session.serialize(msg)
265 smsg = session.serialize(msg)
266 msg2 = session.unserialize(session.feed_identities(smsg)[1])
266 msg2 = session.unserialize(session.feed_identities(smsg)[1])
267 assert isinstance(msg2['header']['date'], datetime)
267 assert isinstance(msg2['header']['date'], datetime)
268 self.assertEqual(msg['header'], msg2['header'])
268 self.assertEqual(msg['header'], msg2['header'])
269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
270 self.assertEqual(msg['parent_header'], msg2['parent_header'])
270 self.assertEqual(msg['parent_header'], msg2['parent_header'])
271 assert isinstance(msg['content']['t'], datetime)
271 assert isinstance(msg['content']['t'], datetime)
272 assert isinstance(msg['metadata']['t'], datetime)
272 assert isinstance(msg['metadata']['t'], datetime)
273 assert isinstance(msg2['content']['t'], string_types)
273 assert isinstance(msg2['content']['t'], string_types)
274 assert isinstance(msg2['metadata']['t'], string_types)
274 assert isinstance(msg2['metadata']['t'], string_types)
275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
276 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
276 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
277
277
278 def test_datetimes(self):
278 def test_datetimes(self):
279 self._datetime_test(self.session)
279 self._datetime_test(self.session)
280
280
281 def test_datetimes_pickle(self):
281 def test_datetimes_pickle(self):
282 session = ss.Session(packer='pickle')
282 session = ss.Session(packer='pickle')
283 self._datetime_test(session)
283 self._datetime_test(session)
284
284
285 @skipif(module_not_available('msgpack'))
285 @skipif(module_not_available('msgpack'))
286 def test_datetimes_msgpack(self):
286 def test_datetimes_msgpack(self):
287 session = ss.Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
287 session = ss.Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
288 self._datetime_test(session)
288 self._datetime_test(session)
289
289
290 def test_send_raw(self):
291 ctx = zmq.Context.instance()
292 A = ctx.socket(zmq.PAIR)
293 B = ctx.socket(zmq.PAIR)
294 A.bind("inproc://test")
295 B.connect("inproc://test")
296
297 msg = self.session.msg('execute', content=dict(a=10))
298 msg_list = [self.session.pack(msg[part]) for part in
299 ['header', 'parent_header', 'metadata', 'content']]
300 self.session.send_raw(A, msg_list, ident=b'foo')
301
302 ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
303 new_msg = self.session.unserialize(new_msg_list)
304 self.assertEqual(ident[0], b'foo')
305 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
306 self.assertEqual(new_msg['header'],msg['header'])
307 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
308 self.assertEqual(new_msg['content'],msg['content'])
309 self.assertEqual(new_msg['metadata'],msg['metadata'])
310
311 A.close()
312 B.close()
313 ctx.term()
General Comments 0
You need to be logged in to leave comments. Login now