##// END OF EJS Templates
Extract session buffers as memoryviews...
Jason Grout -
Show More
@@ -1,873 +1,876 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
24 try:
24 try:
25 import cPickle
25 import cPickle
26 pickle = cPickle
26 pickle = cPickle
27 except:
27 except:
28 cPickle = None
28 cPickle = None
29 import pickle
29 import pickle
30
30
31 try:
31 try:
32 # We are using compare_digest to limit the surface of timing attacks
32 # We are using compare_digest to limit the surface of timing attacks
33 from hmac import compare_digest
33 from hmac import compare_digest
34 except ImportError:
34 except ImportError:
35 # Python < 2.7.7: When digests don't match no feedback is provided,
35 # Python < 2.7.7: When digests don't match no feedback is provided,
36 # limiting the surface of attack
36 # limiting the surface of attack
37 def compare_digest(a,b): return a == b
37 def compare_digest(a,b): return a == b
38
38
39 import zmq
39 import zmq
40 from zmq.utils import jsonapi
40 from zmq.utils import jsonapi
41 from zmq.eventloop.ioloop import IOLoop
41 from zmq.eventloop.ioloop import IOLoop
42 from zmq.eventloop.zmqstream import ZMQStream
42 from zmq.eventloop.zmqstream import ZMQStream
43
43
44 from IPython.core.release import kernel_protocol_version
44 from IPython.core.release import kernel_protocol_version
45 from IPython.config.configurable import Configurable, LoggingConfigurable
45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.utils import io
46 from IPython.utils import io
47 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 iteritems)
50 iteritems)
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode, Dict, Integer,
52 DottedObjectName, CUnicode, Dict, Integer,
53 TraitError,
53 TraitError,
54 )
54 )
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 from IPython.kernel.adapter import adapt
56 from IPython.kernel.adapter import adapt
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58
58
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60 # utility functions
60 # utility functions
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62
62
63 def squash_unicode(obj):
63 def squash_unicode(obj):
64 """coerce unicode back to bytestrings."""
64 """coerce unicode back to bytestrings."""
65 if isinstance(obj,dict):
65 if isinstance(obj,dict):
66 for key in obj.keys():
66 for key in obj.keys():
67 obj[key] = squash_unicode(obj[key])
67 obj[key] = squash_unicode(obj[key])
68 if isinstance(key, unicode_type):
68 if isinstance(key, unicode_type):
69 obj[squash_unicode(key)] = obj.pop(key)
69 obj[squash_unicode(key)] = obj.pop(key)
70 elif isinstance(obj, list):
70 elif isinstance(obj, list):
71 for i,v in enumerate(obj):
71 for i,v in enumerate(obj):
72 obj[i] = squash_unicode(v)
72 obj[i] = squash_unicode(v)
73 elif isinstance(obj, unicode_type):
73 elif isinstance(obj, unicode_type):
74 obj = obj.encode('utf8')
74 obj = obj.encode('utf8')
75 return obj
75 return obj
76
76
77 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
78 # globals and defaults
78 # globals and defaults
79 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
80
80
81 # ISO8601-ify datetime objects
81 # ISO8601-ify datetime objects
82 # allow unicode
82 # allow unicode
83 # disallow nan, because it's not actually valid JSON
83 # disallow nan, because it's not actually valid JSON
84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
85 ensure_ascii=False, allow_nan=False,
85 ensure_ascii=False, allow_nan=False,
86 )
86 )
87 json_unpacker = lambda s: jsonapi.loads(s)
87 json_unpacker = lambda s: jsonapi.loads(s)
88
88
89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
90 pickle_unpacker = pickle.loads
90 pickle_unpacker = pickle.loads
91
91
92 default_packer = json_packer
92 default_packer = json_packer
93 default_unpacker = json_unpacker
93 default_unpacker = json_unpacker
94
94
95 DELIM = b"<IDS|MSG>"
95 DELIM = b"<IDS|MSG>"
96 # singleton dummy tracker, which will always report as done
96 # singleton dummy tracker, which will always report as done
97 DONE = zmq.MessageTracker()
97 DONE = zmq.MessageTracker()
98
98
99 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
100 # Mixin tools for apps that use Sessions
100 # Mixin tools for apps that use Sessions
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102
102
103 session_aliases = dict(
103 session_aliases = dict(
104 ident = 'Session.session',
104 ident = 'Session.session',
105 user = 'Session.username',
105 user = 'Session.username',
106 keyfile = 'Session.keyfile',
106 keyfile = 'Session.keyfile',
107 )
107 )
108
108
109 session_flags = {
109 session_flags = {
110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
111 'keyfile' : '' }},
111 'keyfile' : '' }},
112 """Use HMAC digests for authentication of messages.
112 """Use HMAC digests for authentication of messages.
113 Setting this flag will generate a new UUID to use as the HMAC key.
113 Setting this flag will generate a new UUID to use as the HMAC key.
114 """),
114 """),
115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
116 """Don't authenticate messages."""),
116 """Don't authenticate messages."""),
117 }
117 }
118
118
119 def default_secure(cfg):
119 def default_secure(cfg):
120 """Set the default behavior for a config environment to be secure.
120 """Set the default behavior for a config environment to be secure.
121
121
122 If Session.key/keyfile have not been set, set Session.key to
122 If Session.key/keyfile have not been set, set Session.key to
123 a new random UUID.
123 a new random UUID.
124 """
124 """
125
125
126 if 'Session' in cfg:
126 if 'Session' in cfg:
127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
128 return
128 return
129 # key/keyfile not specified, generate new UUID:
129 # key/keyfile not specified, generate new UUID:
130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
131
131
132
132
133 #-----------------------------------------------------------------------------
133 #-----------------------------------------------------------------------------
134 # Classes
134 # Classes
135 #-----------------------------------------------------------------------------
135 #-----------------------------------------------------------------------------
136
136
137 class SessionFactory(LoggingConfigurable):
137 class SessionFactory(LoggingConfigurable):
138 """The Base class for configurables that have a Session, Context, logger,
138 """The Base class for configurables that have a Session, Context, logger,
139 and IOLoop.
139 and IOLoop.
140 """
140 """
141
141
142 logname = Unicode('')
142 logname = Unicode('')
143 def _logname_changed(self, name, old, new):
143 def _logname_changed(self, name, old, new):
144 self.log = logging.getLogger(new)
144 self.log = logging.getLogger(new)
145
145
146 # not configurable:
146 # not configurable:
147 context = Instance('zmq.Context')
147 context = Instance('zmq.Context')
148 def _context_default(self):
148 def _context_default(self):
149 return zmq.Context.instance()
149 return zmq.Context.instance()
150
150
151 session = Instance('IPython.kernel.zmq.session.Session')
151 session = Instance('IPython.kernel.zmq.session.Session')
152
152
153 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
153 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
154 def _loop_default(self):
154 def _loop_default(self):
155 return IOLoop.instance()
155 return IOLoop.instance()
156
156
157 def __init__(self, **kwargs):
157 def __init__(self, **kwargs):
158 super(SessionFactory, self).__init__(**kwargs)
158 super(SessionFactory, self).__init__(**kwargs)
159
159
160 if self.session is None:
160 if self.session is None:
161 # construct the session
161 # construct the session
162 self.session = Session(**kwargs)
162 self.session = Session(**kwargs)
163
163
164
164
165 class Message(object):
165 class Message(object):
166 """A simple message object that maps dict keys to attributes.
166 """A simple message object that maps dict keys to attributes.
167
167
168 A Message can be created from a dict and a dict from a Message instance
168 A Message can be created from a dict and a dict from a Message instance
169 simply by calling dict(msg_obj)."""
169 simply by calling dict(msg_obj)."""
170
170
171 def __init__(self, msg_dict):
171 def __init__(self, msg_dict):
172 dct = self.__dict__
172 dct = self.__dict__
173 for k, v in iteritems(dict(msg_dict)):
173 for k, v in iteritems(dict(msg_dict)):
174 if isinstance(v, dict):
174 if isinstance(v, dict):
175 v = Message(v)
175 v = Message(v)
176 dct[k] = v
176 dct[k] = v
177
177
178 # Having this iterator lets dict(msg_obj) work out of the box.
178 # Having this iterator lets dict(msg_obj) work out of the box.
179 def __iter__(self):
179 def __iter__(self):
180 return iter(iteritems(self.__dict__))
180 return iter(iteritems(self.__dict__))
181
181
182 def __repr__(self):
182 def __repr__(self):
183 return repr(self.__dict__)
183 return repr(self.__dict__)
184
184
185 def __str__(self):
185 def __str__(self):
186 return pprint.pformat(self.__dict__)
186 return pprint.pformat(self.__dict__)
187
187
188 def __contains__(self, k):
188 def __contains__(self, k):
189 return k in self.__dict__
189 return k in self.__dict__
190
190
191 def __getitem__(self, k):
191 def __getitem__(self, k):
192 return self.__dict__[k]
192 return self.__dict__[k]
193
193
194
194
195 def msg_header(msg_id, msg_type, username, session):
195 def msg_header(msg_id, msg_type, username, session):
196 date = datetime.now()
196 date = datetime.now()
197 version = kernel_protocol_version
197 version = kernel_protocol_version
198 return locals()
198 return locals()
199
199
200 def extract_header(msg_or_header):
200 def extract_header(msg_or_header):
201 """Given a message or header, return the header."""
201 """Given a message or header, return the header."""
202 if not msg_or_header:
202 if not msg_or_header:
203 return {}
203 return {}
204 try:
204 try:
205 # See if msg_or_header is the entire message.
205 # See if msg_or_header is the entire message.
206 h = msg_or_header['header']
206 h = msg_or_header['header']
207 except KeyError:
207 except KeyError:
208 try:
208 try:
209 # See if msg_or_header is just the header
209 # See if msg_or_header is just the header
210 h = msg_or_header['msg_id']
210 h = msg_or_header['msg_id']
211 except KeyError:
211 except KeyError:
212 raise
212 raise
213 else:
213 else:
214 h = msg_or_header
214 h = msg_or_header
215 if not isinstance(h, dict):
215 if not isinstance(h, dict):
216 h = dict(h)
216 h = dict(h)
217 return h
217 return h
218
218
219 class Session(Configurable):
219 class Session(Configurable):
220 """Object for handling serialization and sending of messages.
220 """Object for handling serialization and sending of messages.
221
221
222 The Session object handles building messages and sending them
222 The Session object handles building messages and sending them
223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
224 other over the network via Session objects, and only need to work with the
224 other over the network via Session objects, and only need to work with the
225 dict-based IPython message spec. The Session will handle
225 dict-based IPython message spec. The Session will handle
226 serialization/deserialization, security, and metadata.
226 serialization/deserialization, security, and metadata.
227
227
228 Sessions support configurable serialization via packer/unpacker traits,
228 Sessions support configurable serialization via packer/unpacker traits,
229 and signing with HMAC digests via the key/keyfile traits.
229 and signing with HMAC digests via the key/keyfile traits.
230
230
231 Parameters
231 Parameters
232 ----------
232 ----------
233
233
234 debug : bool
234 debug : bool
235 whether to trigger extra debugging statements
235 whether to trigger extra debugging statements
236 packer/unpacker : str : 'json', 'pickle' or import_string
236 packer/unpacker : str : 'json', 'pickle' or import_string
237 importstrings for methods to serialize message parts. If just
237 importstrings for methods to serialize message parts. If just
238 'json' or 'pickle', predefined JSON and pickle packers will be used.
238 'json' or 'pickle', predefined JSON and pickle packers will be used.
239 Otherwise, the entire importstring must be used.
239 Otherwise, the entire importstring must be used.
240
240
241 The functions must accept at least valid JSON input, and output *bytes*.
241 The functions must accept at least valid JSON input, and output *bytes*.
242
242
243 For example, to use msgpack:
243 For example, to use msgpack:
244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
245 pack/unpack : callables
245 pack/unpack : callables
246 You can also set the pack/unpack callables for serialization directly.
246 You can also set the pack/unpack callables for serialization directly.
247 session : bytes
247 session : bytes
248 the ID of this Session object. The default is to generate a new UUID.
248 the ID of this Session object. The default is to generate a new UUID.
249 username : unicode
249 username : unicode
250 username added to message headers. The default is to ask the OS.
250 username added to message headers. The default is to ask the OS.
251 key : bytes
251 key : bytes
252 The key used to initialize an HMAC signature. If unset, messages
252 The key used to initialize an HMAC signature. If unset, messages
253 will not be signed or checked.
253 will not be signed or checked.
254 keyfile : filepath
254 keyfile : filepath
255 The file containing a key. If this is set, `key` will be initialized
255 The file containing a key. If this is set, `key` will be initialized
256 to the contents of the file.
256 to the contents of the file.
257
257
258 """
258 """
259
259
260 debug=Bool(False, config=True, help="""Debug output in the Session""")
260 debug=Bool(False, config=True, help="""Debug output in the Session""")
261
261
262 packer = DottedObjectName('json',config=True,
262 packer = DottedObjectName('json',config=True,
263 help="""The name of the packer for serializing messages.
263 help="""The name of the packer for serializing messages.
264 Should be one of 'json', 'pickle', or an import name
264 Should be one of 'json', 'pickle', or an import name
265 for a custom callable serializer.""")
265 for a custom callable serializer.""")
266 def _packer_changed(self, name, old, new):
266 def _packer_changed(self, name, old, new):
267 if new.lower() == 'json':
267 if new.lower() == 'json':
268 self.pack = json_packer
268 self.pack = json_packer
269 self.unpack = json_unpacker
269 self.unpack = json_unpacker
270 self.unpacker = new
270 self.unpacker = new
271 elif new.lower() == 'pickle':
271 elif new.lower() == 'pickle':
272 self.pack = pickle_packer
272 self.pack = pickle_packer
273 self.unpack = pickle_unpacker
273 self.unpack = pickle_unpacker
274 self.unpacker = new
274 self.unpacker = new
275 else:
275 else:
276 self.pack = import_item(str(new))
276 self.pack = import_item(str(new))
277
277
278 unpacker = DottedObjectName('json', config=True,
278 unpacker = DottedObjectName('json', config=True,
279 help="""The name of the unpacker for unserializing messages.
279 help="""The name of the unpacker for unserializing messages.
280 Only used with custom functions for `packer`.""")
280 Only used with custom functions for `packer`.""")
281 def _unpacker_changed(self, name, old, new):
281 def _unpacker_changed(self, name, old, new):
282 if new.lower() == 'json':
282 if new.lower() == 'json':
283 self.pack = json_packer
283 self.pack = json_packer
284 self.unpack = json_unpacker
284 self.unpack = json_unpacker
285 self.packer = new
285 self.packer = new
286 elif new.lower() == 'pickle':
286 elif new.lower() == 'pickle':
287 self.pack = pickle_packer
287 self.pack = pickle_packer
288 self.unpack = pickle_unpacker
288 self.unpack = pickle_unpacker
289 self.packer = new
289 self.packer = new
290 else:
290 else:
291 self.unpack = import_item(str(new))
291 self.unpack = import_item(str(new))
292
292
293 session = CUnicode(u'', config=True,
293 session = CUnicode(u'', config=True,
294 help="""The UUID identifying this session.""")
294 help="""The UUID identifying this session.""")
295 def _session_default(self):
295 def _session_default(self):
296 u = unicode_type(uuid.uuid4())
296 u = unicode_type(uuid.uuid4())
297 self.bsession = u.encode('ascii')
297 self.bsession = u.encode('ascii')
298 return u
298 return u
299
299
300 def _session_changed(self, name, old, new):
300 def _session_changed(self, name, old, new):
301 self.bsession = self.session.encode('ascii')
301 self.bsession = self.session.encode('ascii')
302
302
303 # bsession is the session as bytes
303 # bsession is the session as bytes
304 bsession = CBytes(b'')
304 bsession = CBytes(b'')
305
305
306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
307 help="""Username for the Session. Default is your system username.""",
307 help="""Username for the Session. Default is your system username.""",
308 config=True)
308 config=True)
309
309
310 metadata = Dict({}, config=True,
310 metadata = Dict({}, config=True,
311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
312
312
313 # if 0, no adapting to do.
313 # if 0, no adapting to do.
314 adapt_version = Integer(0)
314 adapt_version = Integer(0)
315
315
316 # message signature related traits:
316 # message signature related traits:
317
317
318 key = CBytes(b'', config=True,
318 key = CBytes(b'', config=True,
319 help="""execution key, for extra authentication.""")
319 help="""execution key, for extra authentication.""")
320 def _key_changed(self):
320 def _key_changed(self):
321 self._new_auth()
321 self._new_auth()
322
322
323 signature_scheme = Unicode('hmac-sha256', config=True,
323 signature_scheme = Unicode('hmac-sha256', config=True,
324 help="""The digest scheme used to construct the message signatures.
324 help="""The digest scheme used to construct the message signatures.
325 Must have the form 'hmac-HASH'.""")
325 Must have the form 'hmac-HASH'.""")
326 def _signature_scheme_changed(self, name, old, new):
326 def _signature_scheme_changed(self, name, old, new):
327 if not new.startswith('hmac-'):
327 if not new.startswith('hmac-'):
328 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
328 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
329 hash_name = new.split('-', 1)[1]
329 hash_name = new.split('-', 1)[1]
330 try:
330 try:
331 self.digest_mod = getattr(hashlib, hash_name)
331 self.digest_mod = getattr(hashlib, hash_name)
332 except AttributeError:
332 except AttributeError:
333 raise TraitError("hashlib has no such attribute: %s" % hash_name)
333 raise TraitError("hashlib has no such attribute: %s" % hash_name)
334 self._new_auth()
334 self._new_auth()
335
335
336 digest_mod = Any()
336 digest_mod = Any()
337 def _digest_mod_default(self):
337 def _digest_mod_default(self):
338 return hashlib.sha256
338 return hashlib.sha256
339
339
340 auth = Instance(hmac.HMAC)
340 auth = Instance(hmac.HMAC)
341
341
342 def _new_auth(self):
342 def _new_auth(self):
343 if self.key:
343 if self.key:
344 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
344 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
345 else:
345 else:
346 self.auth = None
346 self.auth = None
347
347
348 digest_history = Set()
348 digest_history = Set()
349 digest_history_size = Integer(2**16, config=True,
349 digest_history_size = Integer(2**16, config=True,
350 help="""The maximum number of digests to remember.
350 help="""The maximum number of digests to remember.
351
351
352 The digest history will be culled when it exceeds this value.
352 The digest history will be culled when it exceeds this value.
353 """
353 """
354 )
354 )
355
355
356 keyfile = Unicode('', config=True,
356 keyfile = Unicode('', config=True,
357 help="""path to file containing execution key.""")
357 help="""path to file containing execution key.""")
358 def _keyfile_changed(self, name, old, new):
358 def _keyfile_changed(self, name, old, new):
359 with open(new, 'rb') as f:
359 with open(new, 'rb') as f:
360 self.key = f.read().strip()
360 self.key = f.read().strip()
361
361
362 # for protecting against sends from forks
362 # for protecting against sends from forks
363 pid = Integer()
363 pid = Integer()
364
364
365 # serialization traits:
365 # serialization traits:
366
366
367 pack = Any(default_packer) # the actual packer function
367 pack = Any(default_packer) # the actual packer function
368 def _pack_changed(self, name, old, new):
368 def _pack_changed(self, name, old, new):
369 if not callable(new):
369 if not callable(new):
370 raise TypeError("packer must be callable, not %s"%type(new))
370 raise TypeError("packer must be callable, not %s"%type(new))
371
371
372 unpack = Any(default_unpacker) # the actual packer function
372 unpack = Any(default_unpacker) # the actual packer function
373 def _unpack_changed(self, name, old, new):
373 def _unpack_changed(self, name, old, new):
374 # unpacker is not checked - it is assumed to be
374 # unpacker is not checked - it is assumed to be
375 if not callable(new):
375 if not callable(new):
376 raise TypeError("unpacker must be callable, not %s"%type(new))
376 raise TypeError("unpacker must be callable, not %s"%type(new))
377
377
378 # thresholds:
378 # thresholds:
379 copy_threshold = Integer(2**16, config=True,
379 copy_threshold = Integer(2**16, config=True,
380 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
380 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
381 buffer_threshold = Integer(MAX_BYTES, config=True,
381 buffer_threshold = Integer(MAX_BYTES, config=True,
382 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
382 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
383 item_threshold = Integer(MAX_ITEMS, config=True,
383 item_threshold = Integer(MAX_ITEMS, config=True,
384 help="""The maximum number of items for a container to be introspected for custom serialization.
384 help="""The maximum number of items for a container to be introspected for custom serialization.
385 Containers larger than this are pickled outright.
385 Containers larger than this are pickled outright.
386 """
386 """
387 )
387 )
388
388
389
389
390 def __init__(self, **kwargs):
390 def __init__(self, **kwargs):
391 """create a Session object
391 """create a Session object
392
392
393 Parameters
393 Parameters
394 ----------
394 ----------
395
395
396 debug : bool
396 debug : bool
397 whether to trigger extra debugging statements
397 whether to trigger extra debugging statements
398 packer/unpacker : str : 'json', 'pickle' or import_string
398 packer/unpacker : str : 'json', 'pickle' or import_string
399 importstrings for methods to serialize message parts. If just
399 importstrings for methods to serialize message parts. If just
400 'json' or 'pickle', predefined JSON and pickle packers will be used.
400 'json' or 'pickle', predefined JSON and pickle packers will be used.
401 Otherwise, the entire importstring must be used.
401 Otherwise, the entire importstring must be used.
402
402
403 The functions must accept at least valid JSON input, and output
403 The functions must accept at least valid JSON input, and output
404 *bytes*.
404 *bytes*.
405
405
406 For example, to use msgpack:
406 For example, to use msgpack:
407 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
407 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
408 pack/unpack : callables
408 pack/unpack : callables
409 You can also set the pack/unpack callables for serialization
409 You can also set the pack/unpack callables for serialization
410 directly.
410 directly.
411 session : unicode (must be ascii)
411 session : unicode (must be ascii)
412 the ID of this Session object. The default is to generate a new
412 the ID of this Session object. The default is to generate a new
413 UUID.
413 UUID.
414 bsession : bytes
414 bsession : bytes
415 The session as bytes
415 The session as bytes
416 username : unicode
416 username : unicode
417 username added to message headers. The default is to ask the OS.
417 username added to message headers. The default is to ask the OS.
418 key : bytes
418 key : bytes
419 The key used to initialize an HMAC signature. If unset, messages
419 The key used to initialize an HMAC signature. If unset, messages
420 will not be signed or checked.
420 will not be signed or checked.
421 signature_scheme : str
421 signature_scheme : str
422 The message digest scheme. Currently must be of the form 'hmac-HASH',
422 The message digest scheme. Currently must be of the form 'hmac-HASH',
423 where 'HASH' is a hashing function available in Python's hashlib.
423 where 'HASH' is a hashing function available in Python's hashlib.
424 The default is 'hmac-sha256'.
424 The default is 'hmac-sha256'.
425 This is ignored if 'key' is empty.
425 This is ignored if 'key' is empty.
426 keyfile : filepath
426 keyfile : filepath
427 The file containing a key. If this is set, `key` will be
427 The file containing a key. If this is set, `key` will be
428 initialized to the contents of the file.
428 initialized to the contents of the file.
429 """
429 """
430 super(Session, self).__init__(**kwargs)
430 super(Session, self).__init__(**kwargs)
431 self._check_packers()
431 self._check_packers()
432 self.none = self.pack({})
432 self.none = self.pack({})
433 # ensure self._session_default() if necessary, so bsession is defined:
433 # ensure self._session_default() if necessary, so bsession is defined:
434 self.session
434 self.session
435 self.pid = os.getpid()
435 self.pid = os.getpid()
436
436
437 @property
437 @property
438 def msg_id(self):
438 def msg_id(self):
439 """always return new uuid"""
439 """always return new uuid"""
440 return str(uuid.uuid4())
440 return str(uuid.uuid4())
441
441
442 def _check_packers(self):
442 def _check_packers(self):
443 """check packers for datetime support."""
443 """check packers for datetime support."""
444 pack = self.pack
444 pack = self.pack
445 unpack = self.unpack
445 unpack = self.unpack
446
446
447 # check simple serialization
447 # check simple serialization
448 msg = dict(a=[1,'hi'])
448 msg = dict(a=[1,'hi'])
449 try:
449 try:
450 packed = pack(msg)
450 packed = pack(msg)
451 except Exception as e:
451 except Exception as e:
452 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
452 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
453 if self.packer == 'json':
453 if self.packer == 'json':
454 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
454 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
455 else:
455 else:
456 jsonmsg = ""
456 jsonmsg = ""
457 raise ValueError(
457 raise ValueError(
458 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
458 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
459 )
459 )
460
460
461 # ensure packed message is bytes
461 # ensure packed message is bytes
462 if not isinstance(packed, bytes):
462 if not isinstance(packed, bytes):
463 raise ValueError("message packed to %r, but bytes are required"%type(packed))
463 raise ValueError("message packed to %r, but bytes are required"%type(packed))
464
464
465 # check that unpack is pack's inverse
465 # check that unpack is pack's inverse
466 try:
466 try:
467 unpacked = unpack(packed)
467 unpacked = unpack(packed)
468 assert unpacked == msg
468 assert unpacked == msg
469 except Exception as e:
469 except Exception as e:
470 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
470 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
471 if self.packer == 'json':
471 if self.packer == 'json':
472 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
472 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
473 else:
473 else:
474 jsonmsg = ""
474 jsonmsg = ""
475 raise ValueError(
475 raise ValueError(
476 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
476 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
477 )
477 )
478
478
479 # check datetime support
479 # check datetime support
480 msg = dict(t=datetime.now())
480 msg = dict(t=datetime.now())
481 try:
481 try:
482 unpacked = unpack(pack(msg))
482 unpacked = unpack(pack(msg))
483 if isinstance(unpacked['t'], datetime):
483 if isinstance(unpacked['t'], datetime):
484 raise ValueError("Shouldn't deserialize to datetime")
484 raise ValueError("Shouldn't deserialize to datetime")
485 except Exception:
485 except Exception:
486 self.pack = lambda o: pack(squash_dates(o))
486 self.pack = lambda o: pack(squash_dates(o))
487 self.unpack = lambda s: unpack(s)
487 self.unpack = lambda s: unpack(s)
488
488
489 def msg_header(self, msg_type):
489 def msg_header(self, msg_type):
490 return msg_header(self.msg_id, msg_type, self.username, self.session)
490 return msg_header(self.msg_id, msg_type, self.username, self.session)
491
491
492 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
492 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
493 """Return the nested message dict.
493 """Return the nested message dict.
494
494
495 This format is different from what is sent over the wire. The
495 This format is different from what is sent over the wire. The
496 serialize/deserialize methods converts this nested message dict to the wire
496 serialize/deserialize methods converts this nested message dict to the wire
497 format, which is a list of message parts.
497 format, which is a list of message parts.
498 """
498 """
499 msg = {}
499 msg = {}
500 header = self.msg_header(msg_type) if header is None else header
500 header = self.msg_header(msg_type) if header is None else header
501 msg['header'] = header
501 msg['header'] = header
502 msg['msg_id'] = header['msg_id']
502 msg['msg_id'] = header['msg_id']
503 msg['msg_type'] = header['msg_type']
503 msg['msg_type'] = header['msg_type']
504 msg['parent_header'] = {} if parent is None else extract_header(parent)
504 msg['parent_header'] = {} if parent is None else extract_header(parent)
505 msg['content'] = {} if content is None else content
505 msg['content'] = {} if content is None else content
506 msg['metadata'] = self.metadata.copy()
506 msg['metadata'] = self.metadata.copy()
507 if metadata is not None:
507 if metadata is not None:
508 msg['metadata'].update(metadata)
508 msg['metadata'].update(metadata)
509 return msg
509 return msg
510
510
511 def sign(self, msg_list):
511 def sign(self, msg_list):
512 """Sign a message with HMAC digest. If no auth, return b''.
512 """Sign a message with HMAC digest. If no auth, return b''.
513
513
514 Parameters
514 Parameters
515 ----------
515 ----------
516 msg_list : list
516 msg_list : list
517 The [p_header,p_parent,p_content] part of the message list.
517 The [p_header,p_parent,p_content] part of the message list.
518 """
518 """
519 if self.auth is None:
519 if self.auth is None:
520 return b''
520 return b''
521 h = self.auth.copy()
521 h = self.auth.copy()
522 for m in msg_list:
522 for m in msg_list:
523 h.update(m)
523 h.update(m)
524 return str_to_bytes(h.hexdigest())
524 return str_to_bytes(h.hexdigest())
525
525
526 def serialize(self, msg, ident=None):
526 def serialize(self, msg, ident=None):
527 """Serialize the message components to bytes.
527 """Serialize the message components to bytes.
528
528
529 This is roughly the inverse of deserialize. The serialize/deserialize
529 This is roughly the inverse of deserialize. The serialize/deserialize
530 methods work with full message lists, whereas pack/unpack work with
530 methods work with full message lists, whereas pack/unpack work with
531 the individual message parts in the message list.
531 the individual message parts in the message list.
532
532
533 Parameters
533 Parameters
534 ----------
534 ----------
535 msg : dict or Message
535 msg : dict or Message
536 The next message dict as returned by the self.msg method.
536 The next message dict as returned by the self.msg method.
537
537
538 Returns
538 Returns
539 -------
539 -------
540 msg_list : list
540 msg_list : list
541 The list of bytes objects to be sent with the format::
541 The list of bytes objects to be sent with the format::
542
542
543 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
543 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
544 p_metadata, p_content, buffer1, buffer2, ...]
544 p_metadata, p_content, buffer1, buffer2, ...]
545
545
546 In this list, the ``p_*`` entities are the packed or serialized
546 In this list, the ``p_*`` entities are the packed or serialized
547 versions, so if JSON is used, these are utf8 encoded JSON strings.
547 versions, so if JSON is used, these are utf8 encoded JSON strings.
548 """
548 """
549 content = msg.get('content', {})
549 content = msg.get('content', {})
550 if content is None:
550 if content is None:
551 content = self.none
551 content = self.none
552 elif isinstance(content, dict):
552 elif isinstance(content, dict):
553 content = self.pack(content)
553 content = self.pack(content)
554 elif isinstance(content, bytes):
554 elif isinstance(content, bytes):
555 # content is already packed, as in a relayed message
555 # content is already packed, as in a relayed message
556 pass
556 pass
557 elif isinstance(content, unicode_type):
557 elif isinstance(content, unicode_type):
558 # should be bytes, but JSON often spits out unicode
558 # should be bytes, but JSON often spits out unicode
559 content = content.encode('utf8')
559 content = content.encode('utf8')
560 else:
560 else:
561 raise TypeError("Content incorrect type: %s"%type(content))
561 raise TypeError("Content incorrect type: %s"%type(content))
562
562
563 real_message = [self.pack(msg['header']),
563 real_message = [self.pack(msg['header']),
564 self.pack(msg['parent_header']),
564 self.pack(msg['parent_header']),
565 self.pack(msg['metadata']),
565 self.pack(msg['metadata']),
566 content,
566 content,
567 ]
567 ]
568
568
569 to_send = []
569 to_send = []
570
570
571 if isinstance(ident, list):
571 if isinstance(ident, list):
572 # accept list of idents
572 # accept list of idents
573 to_send.extend(ident)
573 to_send.extend(ident)
574 elif ident is not None:
574 elif ident is not None:
575 to_send.append(ident)
575 to_send.append(ident)
576 to_send.append(DELIM)
576 to_send.append(DELIM)
577
577
578 signature = self.sign(real_message)
578 signature = self.sign(real_message)
579 to_send.append(signature)
579 to_send.append(signature)
580
580
581 to_send.extend(real_message)
581 to_send.extend(real_message)
582
582
583 return to_send
583 return to_send
584
584
585 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
585 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
586 buffers=None, track=False, header=None, metadata=None):
586 buffers=None, track=False, header=None, metadata=None):
587 """Build and send a message via stream or socket.
587 """Build and send a message via stream or socket.
588
588
589 The message format used by this function internally is as follows:
589 The message format used by this function internally is as follows:
590
590
591 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
591 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
592 buffer1,buffer2,...]
592 buffer1,buffer2,...]
593
593
594 The serialize/deserialize methods convert the nested message dict into this
594 The serialize/deserialize methods convert the nested message dict into this
595 format.
595 format.
596
596
597 Parameters
597 Parameters
598 ----------
598 ----------
599
599
600 stream : zmq.Socket or ZMQStream
600 stream : zmq.Socket or ZMQStream
601 The socket-like object used to send the data.
601 The socket-like object used to send the data.
602 msg_or_type : str or Message/dict
602 msg_or_type : str or Message/dict
603 Normally, msg_or_type will be a msg_type unless a message is being
603 Normally, msg_or_type will be a msg_type unless a message is being
604 sent more than once. If a header is supplied, this can be set to
604 sent more than once. If a header is supplied, this can be set to
605 None and the msg_type will be pulled from the header.
605 None and the msg_type will be pulled from the header.
606
606
607 content : dict or None
607 content : dict or None
608 The content of the message (ignored if msg_or_type is a message).
608 The content of the message (ignored if msg_or_type is a message).
609 header : dict or None
609 header : dict or None
610 The header dict for the message (ignored if msg_to_type is a message).
610 The header dict for the message (ignored if msg_to_type is a message).
611 parent : Message or dict or None
611 parent : Message or dict or None
612 The parent or parent header describing the parent of this message
612 The parent or parent header describing the parent of this message
613 (ignored if msg_or_type is a message).
613 (ignored if msg_or_type is a message).
614 ident : bytes or list of bytes
614 ident : bytes or list of bytes
615 The zmq.IDENTITY routing path.
615 The zmq.IDENTITY routing path.
616 metadata : dict or None
616 metadata : dict or None
617 The metadata describing the message
617 The metadata describing the message
618 buffers : list or None
618 buffers : list or None
619 The already-serialized buffers to be appended to the message.
619 The already-serialized buffers to be appended to the message.
620 track : bool
620 track : bool
621 Whether to track. Only for use with Sockets, because ZMQStream
621 Whether to track. Only for use with Sockets, because ZMQStream
622 objects cannot track messages.
622 objects cannot track messages.
623
623
624
624
625 Returns
625 Returns
626 -------
626 -------
627 msg : dict
627 msg : dict
628 The constructed message.
628 The constructed message.
629 """
629 """
630 if not isinstance(stream, zmq.Socket):
630 if not isinstance(stream, zmq.Socket):
631 # ZMQStreams and dummy sockets do not support tracking.
631 # ZMQStreams and dummy sockets do not support tracking.
632 track = False
632 track = False
633
633
634 if isinstance(msg_or_type, (Message, dict)):
634 if isinstance(msg_or_type, (Message, dict)):
635 # We got a Message or message dict, not a msg_type so don't
635 # We got a Message or message dict, not a msg_type so don't
636 # build a new Message.
636 # build a new Message.
637 msg = msg_or_type
637 msg = msg_or_type
638 buffers = buffers or msg.get('buffers', [])
638 buffers = buffers or msg.get('buffers', [])
639 else:
639 else:
640 msg = self.msg(msg_or_type, content=content, parent=parent,
640 msg = self.msg(msg_or_type, content=content, parent=parent,
641 header=header, metadata=metadata)
641 header=header, metadata=metadata)
642 if not os.getpid() == self.pid:
642 if not os.getpid() == self.pid:
643 io.rprint("WARNING: attempted to send message from fork")
643 io.rprint("WARNING: attempted to send message from fork")
644 io.rprint(msg)
644 io.rprint(msg)
645 return
645 return
646 buffers = [] if buffers is None else buffers
646 buffers = [] if buffers is None else buffers
647 if self.adapt_version:
647 if self.adapt_version:
648 msg = adapt(msg, self.adapt_version)
648 msg = adapt(msg, self.adapt_version)
649 to_send = self.serialize(msg, ident)
649 to_send = self.serialize(msg, ident)
650 to_send.extend(buffers)
650 to_send.extend(buffers)
651 longest = max([ len(s) for s in to_send ])
651 longest = max([ len(s) for s in to_send ])
652 copy = (longest < self.copy_threshold)
652 copy = (longest < self.copy_threshold)
653
653
654 if buffers and track and not copy:
654 if buffers and track and not copy:
655 # only really track when we are doing zero-copy buffers
655 # only really track when we are doing zero-copy buffers
656 tracker = stream.send_multipart(to_send, copy=False, track=True)
656 tracker = stream.send_multipart(to_send, copy=False, track=True)
657 else:
657 else:
658 # use dummy tracker, which will be done immediately
658 # use dummy tracker, which will be done immediately
659 tracker = DONE
659 tracker = DONE
660 stream.send_multipart(to_send, copy=copy)
660 stream.send_multipart(to_send, copy=copy)
661
661
662 if self.debug:
662 if self.debug:
663 pprint.pprint(msg)
663 pprint.pprint(msg)
664 pprint.pprint(to_send)
664 pprint.pprint(to_send)
665 pprint.pprint(buffers)
665 pprint.pprint(buffers)
666
666
667 msg['tracker'] = tracker
667 msg['tracker'] = tracker
668
668
669 return msg
669 return msg
670
670
671 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
671 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
672 """Send a raw message via ident path.
672 """Send a raw message via ident path.
673
673
674 This method is used to send a already serialized message.
674 This method is used to send a already serialized message.
675
675
676 Parameters
676 Parameters
677 ----------
677 ----------
678 stream : ZMQStream or Socket
678 stream : ZMQStream or Socket
679 The ZMQ stream or socket to use for sending the message.
679 The ZMQ stream or socket to use for sending the message.
680 msg_list : list
680 msg_list : list
681 The serialized list of messages to send. This only includes the
681 The serialized list of messages to send. This only includes the
682 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
682 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
683 the message.
683 the message.
684 ident : ident or list
684 ident : ident or list
685 A single ident or a list of idents to use in sending.
685 A single ident or a list of idents to use in sending.
686 """
686 """
687 to_send = []
687 to_send = []
688 if isinstance(ident, bytes):
688 if isinstance(ident, bytes):
689 ident = [ident]
689 ident = [ident]
690 if ident is not None:
690 if ident is not None:
691 to_send.extend(ident)
691 to_send.extend(ident)
692
692
693 to_send.append(DELIM)
693 to_send.append(DELIM)
694 to_send.append(self.sign(msg_list))
694 to_send.append(self.sign(msg_list))
695 to_send.extend(msg_list)
695 to_send.extend(msg_list)
696 stream.send_multipart(to_send, flags, copy=copy)
696 stream.send_multipart(to_send, flags, copy=copy)
697
697
698 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
698 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
699 """Receive and unpack a message.
699 """Receive and unpack a message.
700
700
701 Parameters
701 Parameters
702 ----------
702 ----------
703 socket : ZMQStream or Socket
703 socket : ZMQStream or Socket
704 The socket or stream to use in receiving.
704 The socket or stream to use in receiving.
705
705
706 Returns
706 Returns
707 -------
707 -------
708 [idents], msg
708 [idents], msg
709 [idents] is a list of idents and msg is a nested message dict of
709 [idents] is a list of idents and msg is a nested message dict of
710 same format as self.msg returns.
710 same format as self.msg returns.
711 """
711 """
712 if isinstance(socket, ZMQStream):
712 if isinstance(socket, ZMQStream):
713 socket = socket.socket
713 socket = socket.socket
714 try:
714 try:
715 msg_list = socket.recv_multipart(mode, copy=copy)
715 msg_list = socket.recv_multipart(mode, copy=copy)
716 except zmq.ZMQError as e:
716 except zmq.ZMQError as e:
717 if e.errno == zmq.EAGAIN:
717 if e.errno == zmq.EAGAIN:
718 # We can convert EAGAIN to None as we know in this case
718 # We can convert EAGAIN to None as we know in this case
719 # recv_multipart won't return None.
719 # recv_multipart won't return None.
720 return None,None
720 return None,None
721 else:
721 else:
722 raise
722 raise
723 # split multipart message into identity list and message dict
723 # split multipart message into identity list and message dict
724 # invalid large messages can cause very expensive string comparisons
724 # invalid large messages can cause very expensive string comparisons
725 idents, msg_list = self.feed_identities(msg_list, copy)
725 idents, msg_list = self.feed_identities(msg_list, copy)
726 try:
726 try:
727 return idents, self.deserialize(msg_list, content=content, copy=copy)
727 return idents, self.deserialize(msg_list, content=content, copy=copy)
728 except Exception as e:
728 except Exception as e:
729 # TODO: handle it
729 # TODO: handle it
730 raise e
730 raise e
731
731
732 def feed_identities(self, msg_list, copy=True):
732 def feed_identities(self, msg_list, copy=True):
733 """Split the identities from the rest of the message.
733 """Split the identities from the rest of the message.
734
734
735 Feed until DELIM is reached, then return the prefix as idents and
735 Feed until DELIM is reached, then return the prefix as idents and
736 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
736 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
737 but that would be silly.
737 but that would be silly.
738
738
739 Parameters
739 Parameters
740 ----------
740 ----------
741 msg_list : a list of Message or bytes objects
741 msg_list : a list of Message or bytes objects
742 The message to be split.
742 The message to be split.
743 copy : bool
743 copy : bool
744 flag determining whether the arguments are bytes or Messages
744 flag determining whether the arguments are bytes or Messages
745
745
746 Returns
746 Returns
747 -------
747 -------
748 (idents, msg_list) : two lists
748 (idents, msg_list) : two lists
749 idents will always be a list of bytes, each of which is a ZMQ
749 idents will always be a list of bytes, each of which is a ZMQ
750 identity. msg_list will be a list of bytes or zmq.Messages of the
750 identity. msg_list will be a list of bytes or zmq.Messages of the
751 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
751 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
752 should be unpackable/unserializable via self.deserialize at this
752 should be unpackable/unserializable via self.deserialize at this
753 point.
753 point.
754 """
754 """
755 if copy:
755 if copy:
756 idx = msg_list.index(DELIM)
756 idx = msg_list.index(DELIM)
757 return msg_list[:idx], msg_list[idx+1:]
757 return msg_list[:idx], msg_list[idx+1:]
758 else:
758 else:
759 failed = True
759 failed = True
760 for idx,m in enumerate(msg_list):
760 for idx,m in enumerate(msg_list):
761 if m.bytes == DELIM:
761 if m.bytes == DELIM:
762 failed = False
762 failed = False
763 break
763 break
764 if failed:
764 if failed:
765 raise ValueError("DELIM not in msg_list")
765 raise ValueError("DELIM not in msg_list")
766 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
766 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
767 return [m.bytes for m in idents], msg_list
767 return [m.bytes for m in idents], msg_list
768
768
769 def _add_digest(self, signature):
769 def _add_digest(self, signature):
770 """add a digest to history to protect against replay attacks"""
770 """add a digest to history to protect against replay attacks"""
771 if self.digest_history_size == 0:
771 if self.digest_history_size == 0:
772 # no history, never add digests
772 # no history, never add digests
773 return
773 return
774
774
775 self.digest_history.add(signature)
775 self.digest_history.add(signature)
776 if len(self.digest_history) > self.digest_history_size:
776 if len(self.digest_history) > self.digest_history_size:
777 # threshold reached, cull 10%
777 # threshold reached, cull 10%
778 self._cull_digest_history()
778 self._cull_digest_history()
779
779
780 def _cull_digest_history(self):
780 def _cull_digest_history(self):
781 """cull the digest history
781 """cull the digest history
782
782
783 Removes a randomly selected 10% of the digest history
783 Removes a randomly selected 10% of the digest history
784 """
784 """
785 current = len(self.digest_history)
785 current = len(self.digest_history)
786 n_to_cull = max(int(current // 10), current - self.digest_history_size)
786 n_to_cull = max(int(current // 10), current - self.digest_history_size)
787 if n_to_cull >= current:
787 if n_to_cull >= current:
788 self.digest_history = set()
788 self.digest_history = set()
789 return
789 return
790 to_cull = random.sample(self.digest_history, n_to_cull)
790 to_cull = random.sample(self.digest_history, n_to_cull)
791 self.digest_history.difference_update(to_cull)
791 self.digest_history.difference_update(to_cull)
792
792
793 def deserialize(self, msg_list, content=True, copy=True):
793 def deserialize(self, msg_list, content=True, copy=True):
794 """Unserialize a msg_list to a nested message dict.
794 """Unserialize a msg_list to a nested message dict.
795
795
796 This is roughly the inverse of serialize. The serialize/deserialize
796 This is roughly the inverse of serialize. The serialize/deserialize
797 methods work with full message lists, whereas pack/unpack work with
797 methods work with full message lists, whereas pack/unpack work with
798 the individual message parts in the message list.
798 the individual message parts in the message list.
799
799
800 Parameters
800 Parameters
801 ----------
801 ----------
802 msg_list : list of bytes or Message objects
802 msg_list : list of bytes or Message objects
803 The list of message parts of the form [HMAC,p_header,p_parent,
803 The list of message parts of the form [HMAC,p_header,p_parent,
804 p_metadata,p_content,buffer1,buffer2,...].
804 p_metadata,p_content,buffer1,buffer2,...].
805 content : bool (True)
805 content : bool (True)
806 Whether to unpack the content dict (True), or leave it packed
806 Whether to unpack the content dict (True), or leave it packed
807 (False).
807 (False).
808 copy : bool (True)
808 copy : bool (True)
809 Whether to return the bytes (True), or the non-copying Message
809 Whether to return the bytes (True), or the non-copying Message
810 object in each place (False).
810 object in each place (False).
811
811
812 Returns
812 Returns
813 -------
813 -------
814 msg : dict
814 msg : dict
815 The nested message dict with top-level keys [header, parent_header,
815 The nested message dict with top-level keys [header, parent_header,
816 content, buffers].
816 content, buffers].
817 """
817 """
818 minlen = 5
818 minlen = 5
819 message = {}
819 message = {}
820 if not copy:
820 if not copy:
821 for i in range(minlen):
821 for i in range(minlen):
822 msg_list[i] = msg_list[i].bytes
822 msg_list[i] = msg_list[i].bytes
823 if self.auth is not None:
823 if self.auth is not None:
824 signature = msg_list[0]
824 signature = msg_list[0]
825 if not signature:
825 if not signature:
826 raise ValueError("Unsigned Message")
826 raise ValueError("Unsigned Message")
827 if signature in self.digest_history:
827 if signature in self.digest_history:
828 raise ValueError("Duplicate Signature: %r" % signature)
828 raise ValueError("Duplicate Signature: %r" % signature)
829 self._add_digest(signature)
829 self._add_digest(signature)
830 check = self.sign(msg_list[1:5])
830 check = self.sign(msg_list[1:5])
831 if not compare_digest(signature, check):
831 if not compare_digest(signature, check):
832 raise ValueError("Invalid Signature: %r" % signature)
832 raise ValueError("Invalid Signature: %r" % signature)
833 if not len(msg_list) >= minlen:
833 if not len(msg_list) >= minlen:
834 raise TypeError("malformed message, must have at least %i elements"%minlen)
834 raise TypeError("malformed message, must have at least %i elements"%minlen)
835 header = self.unpack(msg_list[1])
835 header = self.unpack(msg_list[1])
836 message['header'] = extract_dates(header)
836 message['header'] = extract_dates(header)
837 message['msg_id'] = header['msg_id']
837 message['msg_id'] = header['msg_id']
838 message['msg_type'] = header['msg_type']
838 message['msg_type'] = header['msg_type']
839 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
839 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
840 message['metadata'] = self.unpack(msg_list[3])
840 message['metadata'] = self.unpack(msg_list[3])
841 if content:
841 if content:
842 message['content'] = self.unpack(msg_list[4])
842 message['content'] = self.unpack(msg_list[4])
843 else:
843 else:
844 message['content'] = msg_list[4]
844 message['content'] = msg_list[4]
845
845 buffers = [memoryview(b) for b in msg_list[5:]]
846 message['buffers'] = msg_list[5:]
846 if buffers and buffers[0].shape is None:
847 # force copy to workaround pyzmq #646
848 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
849 message['buffers'] = buffers
847 # adapt to the current version
850 # adapt to the current version
848 return adapt(message)
851 return adapt(message)
849
852
850 def unserialize(self, *args, **kwargs):
853 def unserialize(self, *args, **kwargs):
851 warnings.warn(
854 warnings.warn(
852 "Session.unserialize is deprecated. Use Session.deserialize.",
855 "Session.unserialize is deprecated. Use Session.deserialize.",
853 DeprecationWarning,
856 DeprecationWarning,
854 )
857 )
855 return self.deserialize(*args, **kwargs)
858 return self.deserialize(*args, **kwargs)
856
859
857
860
858 def test_msg2obj():
861 def test_msg2obj():
859 am = dict(x=1)
862 am = dict(x=1)
860 ao = Message(am)
863 ao = Message(am)
861 assert ao.x == am['x']
864 assert ao.x == am['x']
862
865
863 am['y'] = dict(z=1)
866 am['y'] = dict(z=1)
864 ao = Message(am)
867 ao = Message(am)
865 assert ao.y.z == am['y']['z']
868 assert ao.y.z == am['y']['z']
866
869
867 k1, k2 = 'y', 'z'
870 k1, k2 = 'y', 'z'
868 assert ao[k1][k2] == am[k1][k2]
871 assert ao[k1][k2] == am[k1][k2]
869
872
870 am2 = dict(ao)
873 am2 = dict(ao)
871 assert am['x'] == am2['x']
874 assert am['x'] == am2['x']
872 assert am['y']['z'] == am2['y']['z']
875 assert am['y']['z'] == am2['y']['z']
873
876
General Comments 0
You need to be logged in to leave comments. Login now