##// END OF EJS Templates
missing allow_none=True
Sylvain Corlay -
Show More
@@ -1,882 +1,882 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
# Prefer the C-accelerated pickle module where it exists (Python 2);
# otherwise fall back to the standard ``pickle`` module.  ``cPickle`` stays
# bound (to None) so later code can test for its availability.
try:
    import cPickle
    pickle = cPickle
except ImportError:
    # Only a missing module should trigger the fallback; anything else
    # (KeyboardInterrupt, a broken install raising another error) propagates.
    cPickle = None
    import pickle
30
30
try:
    # We are using compare_digest to limit the surface of timing attacks
    from hmac import compare_digest
except ImportError:
    # Python < 2.7.7: When digests don't match no feedback is provided,
    # limiting the surface of attack
    def compare_digest(a, b):
        """Fallback digest comparison: return True when ``a`` equals ``b``.

        Used only when ``hmac.compare_digest`` cannot be imported.  This is a
        plain ``==`` comparison and makes no constant-time guarantee.
        """
        return a == b
38
38
39 import zmq
39 import zmq
40 from zmq.utils import jsonapi
40 from zmq.utils import jsonapi
41 from zmq.eventloop.ioloop import IOLoop
41 from zmq.eventloop.ioloop import IOLoop
42 from zmq.eventloop.zmqstream import ZMQStream
42 from zmq.eventloop.zmqstream import ZMQStream
43
43
44 from IPython.core.release import kernel_protocol_version
44 from IPython.core.release import kernel_protocol_version
45 from IPython.config.configurable import Configurable, LoggingConfigurable
45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.utils import io
46 from IPython.utils import io
47 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 iteritems)
50 iteritems)
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode, Dict, Integer,
52 DottedObjectName, CUnicode, Dict, Integer,
53 TraitError,
53 TraitError,
54 )
54 )
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 from IPython.kernel.adapter import adapt
56 from IPython.kernel.adapter import adapt
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58
58
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60 # utility functions
60 # utility functions
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62
62
def squash_unicode(obj):
    """Coerce unicode back to bytestrings.

    Dicts and lists are processed recursively and modified *in place*; the
    same container object is returned.  A unicode string is returned as its
    UTF-8 encoded bytes.  Any other object is returned unchanged.

    Parameters
    ----------
    obj : object
        The value to coerce.

    Returns
    -------
    object
        ``obj`` itself (for dicts, lists and non-unicode values) or the
        UTF-8 encoding of ``obj`` (for unicode strings).
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: unicode keys are popped and
        # re-inserted under their encoded form below, and mutating a dict
        # while iterating its live key view is unsafe.
        for key in list(obj):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
76
76
77 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
78 # globals and defaults
78 # globals and defaults
79 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
80
80
# ISO8601-ify datetime objects
# allow unicode
# disallow nan, because it's not actually valid JSON
def json_packer(obj):
    """Serialize ``obj`` with ``jsonapi.dumps``.

    ``date_default`` is used as the JSON ``default`` hook, non-ASCII
    characters are emitted as-is, and NaN/Infinity are rejected.
    """
    return jsonapi.dumps(obj, default=date_default,
                         ensure_ascii=False, allow_nan=False)


def json_unpacker(s):
    """Deserialize ``s`` with ``jsonapi.loads``."""
    return jsonapi.loads(s)


def pickle_packer(o):
    """Pickle ``o`` with PICKLE_PROTOCOL after passing it through ``squash_dates``."""
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


# NOTE(review): unpickling executes arbitrary code; this unpacker must only
# ever be applied to data from a trusted peer.
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
98
98
99 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
100 # Mixin tools for apps that use Sessions
100 # Mixin tools for apps that use Sessions
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102
102
# Short command-line alias -> configurable trait name on Session.
session_aliases = dict(
    ident='Session.session',
    user='Session.username',
    keyfile='Session.keyfile',
)
108
108
# Command-line flag name -> (config overrides for Session, help text).
# The 'secure' key is generated once, when this module is imported.
session_flags = {
    'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
118
118
def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.

    Deprecated: a DeprecationWarning is emitted on every call.

    Parameters
    ----------
    cfg : config object
        Must support ``'Session' in cfg`` and attribute access to
        ``cfg.Session``.  NOTE(review): when ``cfg`` has no ``Session``
        section, ``cfg.Session`` is expected to create one on access --
        confirm against the config class in use.

    Returns
    -------
    None
        ``cfg`` is modified in place.
    """
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn("default_secure is deprecated", DeprecationWarning,
                  stacklevel=2)
    if 'Session' in cfg:
        if 'key' in cfg.Session or 'keyfile' in cfg.Session:
            return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
131
131
132
132
133 #-----------------------------------------------------------------------------
133 #-----------------------------------------------------------------------------
134 # Classes
134 # Classes
135 #-----------------------------------------------------------------------------
135 #-----------------------------------------------------------------------------
136
136
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.

    If no ``session`` is supplied at construction time, a new Session is
    built from the same keyword arguments.
    """

    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        # Re-bind the logger whenever the logger name changes.
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    # May legitimately be None until __init__ has filled it in.
    session = Instance('IPython.kernel.zmq.session.Session',
                       allow_none=True)

    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
164
164
165
165
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj).

    Values that are themselves dicts are wrapped in nested Message objects,
    so ``dict(msg_obj)`` converts only the top level.
    """

    def __init__(self, msg_dict):
        """Populate attributes from ``msg_dict`` (anything ``dict()`` accepts)."""
        dct = self.__dict__
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        """Iterate over ``(key, value)`` pairs."""
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        """Return the value stored under ``k``; raises KeyError if absent."""
        return self.__dict__[k]
194
194
195
195
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    Returns a dict with the four arguments under their own names, plus
    ``date`` (the current local time from ``datetime.now()``) and
    ``version`` (``kernel_protocol_version``).
    """
    # Spelled out explicitly rather than returning locals(), so the header
    # schema cannot change by accident when a local variable is added.
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
200
200
def extract_header(msg_or_header):
    """Given a message or header, return the header.

    Parameters
    ----------
    msg_or_header : mapping-like or None
        Either a full message (has a ``'header'`` key) or a bare header
        (has a ``'msg_id'`` key).  Any falsy value yields an empty dict.

    Returns
    -------
    dict
        The header.  A header that is already a dict is returned as-is (not
        copied); any other object is converted with ``dict()``.

    Raises
    ------
    KeyError
        If the argument has neither a ``'header'`` nor a ``'msg_id'`` key.
    """
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        h = msg_or_header['header']
    except KeyError:
        # See if msg_or_header is just the header: the lookup is only a
        # probe, and its KeyError propagates when 'msg_id' is missing too.
        msg_or_header['msg_id']
        h = msg_or_header
    if not isinstance(h, dict):
        h = dict(h)
    return h
219
219
220 class Session(Configurable):
220 class Session(Configurable):
221 """Object for handling serialization and sending of messages.
221 """Object for handling serialization and sending of messages.
222
222
223 The Session object handles building messages and sending them
223 The Session object handles building messages and sending them
224 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
224 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
225 other over the network via Session objects, and only need to work with the
225 other over the network via Session objects, and only need to work with the
226 dict-based IPython message spec. The Session will handle
226 dict-based IPython message spec. The Session will handle
227 serialization/deserialization, security, and metadata.
227 serialization/deserialization, security, and metadata.
228
228
229 Sessions support configurable serialization via packer/unpacker traits,
229 Sessions support configurable serialization via packer/unpacker traits,
230 and signing with HMAC digests via the key/keyfile traits.
230 and signing with HMAC digests via the key/keyfile traits.
231
231
232 Parameters
232 Parameters
233 ----------
233 ----------
234
234
235 debug : bool
235 debug : bool
236 whether to trigger extra debugging statements
236 whether to trigger extra debugging statements
237 packer/unpacker : str : 'json', 'pickle' or import_string
237 packer/unpacker : str : 'json', 'pickle' or import_string
238 importstrings for methods to serialize message parts. If just
238 importstrings for methods to serialize message parts. If just
239 'json' or 'pickle', predefined JSON and pickle packers will be used.
239 'json' or 'pickle', predefined JSON and pickle packers will be used.
240 Otherwise, the entire importstring must be used.
240 Otherwise, the entire importstring must be used.
241
241
242 The functions must accept at least valid JSON input, and output *bytes*.
242 The functions must accept at least valid JSON input, and output *bytes*.
243
243
244 For example, to use msgpack:
244 For example, to use msgpack:
245 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
245 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
246 pack/unpack : callables
246 pack/unpack : callables
247 You can also set the pack/unpack callables for serialization directly.
247 You can also set the pack/unpack callables for serialization directly.
248 session : bytes
248 session : bytes
249 the ID of this Session object. The default is to generate a new UUID.
249 the ID of this Session object. The default is to generate a new UUID.
250 username : unicode
250 username : unicode
251 username added to message headers. The default is to ask the OS.
251 username added to message headers. The default is to ask the OS.
252 key : bytes
252 key : bytes
253 The key used to initialize an HMAC signature. If unset, messages
253 The key used to initialize an HMAC signature. If unset, messages
254 will not be signed or checked.
254 will not be signed or checked.
255 keyfile : filepath
255 keyfile : filepath
256 The file containing a key. If this is set, `key` will be initialized
256 The file containing a key. If this is set, `key` will be initialized
257 to the contents of the file.
257 to the contents of the file.
258
258
259 """
259 """
260
260
261 debug=Bool(False, config=True, help="""Debug output in the Session""")
261 debug=Bool(False, config=True, help="""Debug output in the Session""")
262
262
263 packer = DottedObjectName('json',config=True,
263 packer = DottedObjectName('json',config=True,
264 help="""The name of the packer for serializing messages.
264 help="""The name of the packer for serializing messages.
265 Should be one of 'json', 'pickle', or an import name
265 Should be one of 'json', 'pickle', or an import name
266 for a custom callable serializer.""")
266 for a custom callable serializer.""")
267 def _packer_changed(self, name, old, new):
267 def _packer_changed(self, name, old, new):
268 if new.lower() == 'json':
268 if new.lower() == 'json':
269 self.pack = json_packer
269 self.pack = json_packer
270 self.unpack = json_unpacker
270 self.unpack = json_unpacker
271 self.unpacker = new
271 self.unpacker = new
272 elif new.lower() == 'pickle':
272 elif new.lower() == 'pickle':
273 self.pack = pickle_packer
273 self.pack = pickle_packer
274 self.unpack = pickle_unpacker
274 self.unpack = pickle_unpacker
275 self.unpacker = new
275 self.unpacker = new
276 else:
276 else:
277 self.pack = import_item(str(new))
277 self.pack = import_item(str(new))
278
278
279 unpacker = DottedObjectName('json', config=True,
279 unpacker = DottedObjectName('json', config=True,
280 help="""The name of the unpacker for unserializing messages.
280 help="""The name of the unpacker for unserializing messages.
281 Only used with custom functions for `packer`.""")
281 Only used with custom functions for `packer`.""")
282 def _unpacker_changed(self, name, old, new):
282 def _unpacker_changed(self, name, old, new):
283 if new.lower() == 'json':
283 if new.lower() == 'json':
284 self.pack = json_packer
284 self.pack = json_packer
285 self.unpack = json_unpacker
285 self.unpack = json_unpacker
286 self.packer = new
286 self.packer = new
287 elif new.lower() == 'pickle':
287 elif new.lower() == 'pickle':
288 self.pack = pickle_packer
288 self.pack = pickle_packer
289 self.unpack = pickle_unpacker
289 self.unpack = pickle_unpacker
290 self.packer = new
290 self.packer = new
291 else:
291 else:
292 self.unpack = import_item(str(new))
292 self.unpack = import_item(str(new))
293
293
294 session = CUnicode(u'', config=True,
294 session = CUnicode(u'', config=True,
295 help="""The UUID identifying this session.""")
295 help="""The UUID identifying this session.""")
296 def _session_default(self):
296 def _session_default(self):
297 u = unicode_type(uuid.uuid4())
297 u = unicode_type(uuid.uuid4())
298 self.bsession = u.encode('ascii')
298 self.bsession = u.encode('ascii')
299 return u
299 return u
300
300
301 def _session_changed(self, name, old, new):
301 def _session_changed(self, name, old, new):
302 self.bsession = self.session.encode('ascii')
302 self.bsession = self.session.encode('ascii')
303
303
304 # bsession is the session as bytes
304 # bsession is the session as bytes
305 bsession = CBytes(b'')
305 bsession = CBytes(b'')
306
306
307 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
307 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
308 help="""Username for the Session. Default is your system username.""",
308 help="""Username for the Session. Default is your system username.""",
309 config=True)
309 config=True)
310
310
311 metadata = Dict({}, config=True,
311 metadata = Dict({}, config=True,
312 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
312 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
313
313
314 # if 0, no adapting to do.
314 # if 0, no adapting to do.
315 adapt_version = Integer(0)
315 adapt_version = Integer(0)
316
316
317 # message signature related traits:
317 # message signature related traits:
318
318
319 key = CBytes(config=True,
319 key = CBytes(config=True,
320 help="""execution key, for signing messages.""")
320 help="""execution key, for signing messages.""")
321 def _key_default(self):
321 def _key_default(self):
322 return str_to_bytes(str(uuid.uuid4()))
322 return str_to_bytes(str(uuid.uuid4()))
323
323
324 def _key_changed(self):
324 def _key_changed(self):
325 self._new_auth()
325 self._new_auth()
326
326
327 signature_scheme = Unicode('hmac-sha256', config=True,
327 signature_scheme = Unicode('hmac-sha256', config=True,
328 help="""The digest scheme used to construct the message signatures.
328 help="""The digest scheme used to construct the message signatures.
329 Must have the form 'hmac-HASH'.""")
329 Must have the form 'hmac-HASH'.""")
330 def _signature_scheme_changed(self, name, old, new):
330 def _signature_scheme_changed(self, name, old, new):
331 if not new.startswith('hmac-'):
331 if not new.startswith('hmac-'):
332 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
332 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
333 hash_name = new.split('-', 1)[1]
333 hash_name = new.split('-', 1)[1]
334 try:
334 try:
335 self.digest_mod = getattr(hashlib, hash_name)
335 self.digest_mod = getattr(hashlib, hash_name)
336 except AttributeError:
336 except AttributeError:
337 raise TraitError("hashlib has no such attribute: %s" % hash_name)
337 raise TraitError("hashlib has no such attribute: %s" % hash_name)
338 self._new_auth()
338 self._new_auth()
339
339
340 digest_mod = Any()
340 digest_mod = Any()
341 def _digest_mod_default(self):
341 def _digest_mod_default(self):
342 return hashlib.sha256
342 return hashlib.sha256
343
343
344 auth = Instance(hmac.HMAC)
344 auth = Instance(hmac.HMAC, allow_none=True)
345
345
346 def _new_auth(self):
346 def _new_auth(self):
347 if self.key:
347 if self.key:
348 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
348 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
349 else:
349 else:
350 self.auth = None
350 self.auth = None
351
351
352 digest_history = Set()
352 digest_history = Set()
353 digest_history_size = Integer(2**16, config=True,
353 digest_history_size = Integer(2**16, config=True,
354 help="""The maximum number of digests to remember.
354 help="""The maximum number of digests to remember.
355
355
356 The digest history will be culled when it exceeds this value.
356 The digest history will be culled when it exceeds this value.
357 """
357 """
358 )
358 )
359
359
360 keyfile = Unicode('', config=True,
360 keyfile = Unicode('', config=True,
361 help="""path to file containing execution key.""")
361 help="""path to file containing execution key.""")
362 def _keyfile_changed(self, name, old, new):
362 def _keyfile_changed(self, name, old, new):
363 with open(new, 'rb') as f:
363 with open(new, 'rb') as f:
364 self.key = f.read().strip()
364 self.key = f.read().strip()
365
365
366 # for protecting against sends from forks
366 # for protecting against sends from forks
367 pid = Integer()
367 pid = Integer()
368
368
369 # serialization traits:
369 # serialization traits:
370
370
371 pack = Any(default_packer) # the actual packer function
371 pack = Any(default_packer) # the actual packer function
372 def _pack_changed(self, name, old, new):
372 def _pack_changed(self, name, old, new):
373 if not callable(new):
373 if not callable(new):
374 raise TypeError("packer must be callable, not %s"%type(new))
374 raise TypeError("packer must be callable, not %s"%type(new))
375
375
376 unpack = Any(default_unpacker) # the actual packer function
376 unpack = Any(default_unpacker) # the actual packer function
377 def _unpack_changed(self, name, old, new):
377 def _unpack_changed(self, name, old, new):
378 # unpacker is not checked - it is assumed to be
378 # unpacker is not checked - it is assumed to be
379 if not callable(new):
379 if not callable(new):
380 raise TypeError("unpacker must be callable, not %s"%type(new))
380 raise TypeError("unpacker must be callable, not %s"%type(new))
381
381
382 # thresholds:
382 # thresholds:
383 copy_threshold = Integer(2**16, config=True,
383 copy_threshold = Integer(2**16, config=True,
384 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
384 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
385 buffer_threshold = Integer(MAX_BYTES, config=True,
385 buffer_threshold = Integer(MAX_BYTES, config=True,
386 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
386 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
387 item_threshold = Integer(MAX_ITEMS, config=True,
387 item_threshold = Integer(MAX_ITEMS, config=True,
388 help="""The maximum number of items for a container to be introspected for custom serialization.
388 help="""The maximum number of items for a container to be introspected for custom serialization.
389 Containers larger than this are pickled outright.
389 Containers larger than this are pickled outright.
390 """
390 """
391 )
391 )
392
392
393
393
394 def __init__(self, **kwargs):
394 def __init__(self, **kwargs):
395 """create a Session object
395 """create a Session object
396
396
397 Parameters
397 Parameters
398 ----------
398 ----------
399
399
400 debug : bool
400 debug : bool
401 whether to trigger extra debugging statements
401 whether to trigger extra debugging statements
402 packer/unpacker : str : 'json', 'pickle' or import_string
402 packer/unpacker : str : 'json', 'pickle' or import_string
403 importstrings for methods to serialize message parts. If just
403 importstrings for methods to serialize message parts. If just
404 'json' or 'pickle', predefined JSON and pickle packers will be used.
404 'json' or 'pickle', predefined JSON and pickle packers will be used.
405 Otherwise, the entire importstring must be used.
405 Otherwise, the entire importstring must be used.
406
406
407 The functions must accept at least valid JSON input, and output
407 The functions must accept at least valid JSON input, and output
408 *bytes*.
408 *bytes*.
409
409
410 For example, to use msgpack:
410 For example, to use msgpack:
411 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
411 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
412 pack/unpack : callables
412 pack/unpack : callables
413 You can also set the pack/unpack callables for serialization
413 You can also set the pack/unpack callables for serialization
414 directly.
414 directly.
415 session : unicode (must be ascii)
415 session : unicode (must be ascii)
416 the ID of this Session object. The default is to generate a new
416 the ID of this Session object. The default is to generate a new
417 UUID.
417 UUID.
418 bsession : bytes
418 bsession : bytes
419 The session as bytes
419 The session as bytes
420 username : unicode
420 username : unicode
421 username added to message headers. The default is to ask the OS.
421 username added to message headers. The default is to ask the OS.
422 key : bytes
422 key : bytes
423 The key used to initialize an HMAC signature. If unset, messages
423 The key used to initialize an HMAC signature. If unset, messages
424 will not be signed or checked.
424 will not be signed or checked.
425 signature_scheme : str
425 signature_scheme : str
426 The message digest scheme. Currently must be of the form 'hmac-HASH',
426 The message digest scheme. Currently must be of the form 'hmac-HASH',
427 where 'HASH' is a hashing function available in Python's hashlib.
427 where 'HASH' is a hashing function available in Python's hashlib.
428 The default is 'hmac-sha256'.
428 The default is 'hmac-sha256'.
429 This is ignored if 'key' is empty.
429 This is ignored if 'key' is empty.
430 keyfile : filepath
430 keyfile : filepath
431 The file containing a key. If this is set, `key` will be
431 The file containing a key. If this is set, `key` will be
432 initialized to the contents of the file.
432 initialized to the contents of the file.
433 """
433 """
434 super(Session, self).__init__(**kwargs)
434 super(Session, self).__init__(**kwargs)
435 self._check_packers()
435 self._check_packers()
436 self.none = self.pack({})
436 self.none = self.pack({})
437 # ensure self._session_default() if necessary, so bsession is defined:
437 # ensure self._session_default() if necessary, so bsession is defined:
438 self.session
438 self.session
439 self.pid = os.getpid()
439 self.pid = os.getpid()
440 self._new_auth()
440 self._new_auth()
441
441
442 @property
442 @property
443 def msg_id(self):
443 def msg_id(self):
444 """always return new uuid"""
444 """always return new uuid"""
445 return str(uuid.uuid4())
445 return str(uuid.uuid4())
446
446
447 def _check_packers(self):
447 def _check_packers(self):
448 """check packers for datetime support."""
448 """check packers for datetime support."""
449 pack = self.pack
449 pack = self.pack
450 unpack = self.unpack
450 unpack = self.unpack
451
451
452 # check simple serialization
452 # check simple serialization
453 msg = dict(a=[1,'hi'])
453 msg = dict(a=[1,'hi'])
454 try:
454 try:
455 packed = pack(msg)
455 packed = pack(msg)
456 except Exception as e:
456 except Exception as e:
457 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
457 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
458 if self.packer == 'json':
458 if self.packer == 'json':
459 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
459 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
460 else:
460 else:
461 jsonmsg = ""
461 jsonmsg = ""
462 raise ValueError(
462 raise ValueError(
463 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
463 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
464 )
464 )
465
465
466 # ensure packed message is bytes
466 # ensure packed message is bytes
467 if not isinstance(packed, bytes):
467 if not isinstance(packed, bytes):
468 raise ValueError("message packed to %r, but bytes are required"%type(packed))
468 raise ValueError("message packed to %r, but bytes are required"%type(packed))
469
469
470 # check that unpack is pack's inverse
470 # check that unpack is pack's inverse
471 try:
471 try:
472 unpacked = unpack(packed)
472 unpacked = unpack(packed)
473 assert unpacked == msg
473 assert unpacked == msg
474 except Exception as e:
474 except Exception as e:
475 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
475 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
476 if self.packer == 'json':
476 if self.packer == 'json':
477 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
477 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
478 else:
478 else:
479 jsonmsg = ""
479 jsonmsg = ""
480 raise ValueError(
480 raise ValueError(
481 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
481 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
482 )
482 )
483
483
484 # check datetime support
484 # check datetime support
485 msg = dict(t=datetime.now())
485 msg = dict(t=datetime.now())
486 try:
486 try:
487 unpacked = unpack(pack(msg))
487 unpacked = unpack(pack(msg))
488 if isinstance(unpacked['t'], datetime):
488 if isinstance(unpacked['t'], datetime):
489 raise ValueError("Shouldn't deserialize to datetime")
489 raise ValueError("Shouldn't deserialize to datetime")
490 except Exception:
490 except Exception:
491 self.pack = lambda o: pack(squash_dates(o))
491 self.pack = lambda o: pack(squash_dates(o))
492 self.unpack = lambda s: unpack(s)
492 self.unpack = lambda s: unpack(s)
493
493
def msg_header(self, msg_type):
    """Return a new header for a message of type ``msg_type``.

    Delegates to the module-level ``msg_header`` helper, passing it this
    session's ``msg_id``, ``username`` and ``session`` attributes.
    """
    new_id = self.msg_id
    return msg_header(new_id, msg_type, self.username, self.session)
496
496
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
    """Assemble and return the nested message dict.

    The nested dict is not the wire format; serialize/deserialize convert
    between this dict and the list of message parts that is actually sent.

    Parameters
    ----------
    msg_type : str
        Message type, used to build a header when ``header`` is None.
    content : dict or None
        Message content; an empty dict is used when None.
    parent : Message or dict or None
        Parent message; its header (via ``extract_header``) becomes
        ``parent_header``. An empty dict is used when None.
    header : dict or None
        Pre-built header; ``self.msg_header(msg_type)`` is used when None.
    metadata : dict or None
        Extra metadata merged over a copy of ``self.metadata``.

    Returns
    -------
    msg : dict
        Dict with keys ``header``, ``msg_id``, ``msg_type``,
        ``parent_header``, ``content`` and ``metadata``.
    """
    if header is None:
        header = self.msg_header(msg_type)
    nested = {
        'header': header,
        'msg_id': header['msg_id'],
        'msg_type': header['msg_type'],
        'parent_header': extract_header(parent) if parent is not None else {},
        'content': content if content is not None else {},
        # copy so per-message metadata never leaks into the session default
        'metadata': self.metadata.copy(),
    }
    if metadata is not None:
        nested['metadata'].update(metadata)
    return nested
515
515
def sign(self, msg_list):
    """Compute the HMAC signature for a list of packed message parts.

    Parameters
    ----------
    msg_list : list
        The packed message parts to sign, fed to the digest in order.

    Returns
    -------
    signature : bytes
        The hex digest as bytes, or ``b''`` when the session has no
        ``auth`` object.
    """
    if self.auth is not None:
        # work on a copy so the session's keyed HMAC template stays pristine
        mac = self.auth.copy()
        for part in msg_list:
            mac.update(part)
        return str_to_bytes(mac.hexdigest())
    return b''
530
530
def serialize(self, msg, ident=None):
    """Turn a nested message dict into the list of bytes sent on the wire.

    Roughly the inverse of deserialize. serialize/deserialize operate on
    whole message lists, while pack/unpack operate on single parts.

    Parameters
    ----------
    msg : dict or Message
        The nested message dict as returned by the self.msg method.
    ident : bytes or list of bytes or None
        Routing identity (or identities) placed before the delimiter.

    Returns
    -------
    msg_list : list
        The list of bytes objects to be sent with the format::

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
             p_metadata, p_content, buffer1, buffer2, ...]

        The ``p_*`` entries are the packed versions of each part. No
        buffers are appended here; callers add them afterwards.

    Raises
    ------
    TypeError
        If the content is not None, a dict, bytes or unicode.
    """
    body = msg.get('content', {})
    if body is None:
        body = self.none
    elif isinstance(body, dict):
        body = self.pack(body)
    elif isinstance(body, unicode_type):
        # should be bytes, but JSON often spits out unicode
        body = body.encode('utf8')
    elif not isinstance(body, bytes):
        # bytes means already packed (e.g. a relayed message); anything
        # else is unsupported
        raise TypeError("Content incorrect type: %s"%type(body))

    packed_parts = [
        self.pack(msg[key])
        for key in ('header', 'parent_header', 'metadata')
    ]
    packed_parts.append(body)

    if isinstance(ident, list):
        # accept list of idents
        wire = list(ident)
    elif ident is not None:
        wire = [ident]
    else:
        wire = []
    wire.append(DELIM)
    wire.append(self.sign(packed_parts))
    wire += packed_parts
    return wire
589
589
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, track=False, header=None, metadata=None):
    """Build a message (if needed), serialize it and send it on ``stream``.

    The list handed to ``stream.send_multipart`` has the form::

        [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
         p_metadata, p_content, buffer1, buffer2, ...]

    Parameters
    ----------
    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally a msg_type. When a Message or dict is given it is sent
        as-is instead of building a new message. If a header is supplied,
        this can be None and the msg_type is taken from the header.
    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    parent : Message or dict or None
        The parent or parent header of this message
        (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    buffers : list or None
        The already-serialized buffers to be appended to the message.
        When msg_or_type is a message and no buffers are given, the
        message's own ``'buffers'`` entry is used.
    track : bool
        Whether to track. Forced to False unless ``stream`` is a
        zmq.Socket, because other stream types cannot track messages.
    header : dict or None
        The header dict for the message (ignored if msg_or_type is a
        message).
    metadata : dict or None
        The metadata describing the message
        (ignored if msg_or_type is a message).

    Returns
    -------
    msg : dict or None
        The message that was sent, with a ``'tracker'`` entry added.
        None is returned, and nothing is sent, when called from a process
        whose pid differs from ``self.pid``.
    """
    # ZMQStreams and dummy sockets do not support tracking.
    track = track if isinstance(stream, zmq.Socket) else False

    if isinstance(msg_or_type, (Message, dict)):
        # Already a message: send it rather than building a new one.
        msg = msg_or_type
        buffers = buffers or msg.get('buffers', [])
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       header=header, metadata=metadata)

    if os.getpid() != self.pid:
        io.rprint("WARNING: attempted to send message from fork")
        io.rprint(msg)
        return

    if buffers is None:
        buffers = []
    if self.adapt_version:
        msg = adapt(msg, self.adapt_version)

    to_send = self.serialize(msg, ident)
    to_send.extend(buffers)
    # copy small messages; go zero-copy once any part reaches the threshold
    longest = max(len(part) for part in to_send)
    copy = longest < self.copy_threshold

    if not (buffers and track and not copy):
        # use dummy tracker, which will be done immediately
        tracker = DONE
        stream.send_multipart(to_send, copy=copy)
    else:
        # only really track when we are doing zero-copy buffers
        tracker = stream.send_multipart(to_send, copy=False, track=True)

    if self.debug:
        for dumped in (msg, to_send, buffers):
            pprint.pprint(dumped)

    msg['tracker'] = tracker
    return msg
675
675
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Sign and send an already-serialized message via an ident path.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized message parts to send. This only includes the
        [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
        portion of the message; idents, delimiter and signature are
        prepended here.
    flags : int
        Passed positionally to ``stream.send_multipart``.
    copy : bool
        Passed as ``copy=`` to ``stream.send_multipart``.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    if ident is None:
        routing = []
    elif isinstance(ident, bytes):
        routing = [ident]
    else:
        routing = list(ident)

    frames = routing + [DELIM, self.sign(msg_list)]
    frames.extend(msg_list)
    stream.send_multipart(frames, flags, copy=copy)
702
702
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving. For a ZMQStream the
        underlying ``.socket`` is used.
    mode : int
        Passed as the first positional argument to
        ``socket.recv_multipart``.
    content : bool
        Forwarded to ``self.deserialize``.
    copy : bool
        Forwarded to ``socket.recv_multipart``, ``self.feed_identities``
        and ``self.deserialize``.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns. ``(None, None)`` is returned when
        the receive fails with ``zmq.EAGAIN``.

    Raises
    ------
    zmq.ZMQError
        Any receive error other than EAGAIN.
    Exception
        Whatever ``feed_identities`` or ``deserialize`` raise for a bad
        message; these propagate unchanged with their original traceback.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode, copy=copy)
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            # We can convert EAGAIN to None as we know in this case
            # recv_multipart won't return None.
            return None, None
        raise
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    # No try/except here: a handler that only does `raise e` adds nothing
    # and obscures where the failure actually happened.
    return idents, self.deserialize(msg_list, content=content, copy=copy)
736
736
def feed_identities(self, msg_list, copy=True):
    """Split the routing identities off the front of a message list.

    Everything before the first DELIM part is returned as idents and
    everything after it as the remaining message. An ident equal to DELIM
    would break this split.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        Whether the parts are bytes (True) or objects exposing their
        payload as ``.bytes`` (False).

    Returns
    -------
    (idents, msg_list) : two lists
        With ``copy=False`` idents is converted to a list of bytes; with
        ``copy=True`` it is a plain slice of the input. msg_list is the
        slice after DELIM, of the form
        [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...],
        ready for self.deserialize.

    Raises
    ------
    ValueError
        If DELIM does not occur in msg_list.
    """
    if copy:
        split_at = msg_list.index(DELIM)
        return msg_list[:split_at], msg_list[split_at+1:]

    for split_at, part in enumerate(msg_list):
        if part.bytes == DELIM:
            break
    else:
        # loop ran to completion without meeting the delimiter
        raise ValueError("DELIM not in msg_list")

    ident_parts = msg_list[:split_at]
    remainder = msg_list[split_at+1:]
    return [p.bytes for p in ident_parts], remainder
773
773
def _add_digest(self, signature):
    """Record a signature in the digest history (replay protection).

    Nothing is recorded when ``digest_history_size`` is 0. Once the
    history grows beyond ``digest_history_size``, it is culled via
    ``_cull_digest_history``.
    """
    if self.digest_history_size != 0:
        history = self.digest_history
        history.add(signature)
        if len(history) > self.digest_history_size:
            # threshold reached, trim the history
            self._cull_digest_history()
784
784
def _cull_digest_history(self):
    """Shrink the digest history by discarding randomly chosen entries.

    The number removed is the larger of 10% of the current history
    (rounded down) and the amount by which the history exceeds
    ``digest_history_size``. If that would remove everything, the history
    is replaced with an empty set.
    """
    current = len(self.digest_history)
    n_to_cull = max(current // 10, current - self.digest_history_size)
    if n_to_cull >= current:
        self.digest_history = set()
        return
    # random.sample() needs a sequence: sampling directly from a set is
    # deprecated since Python 3.9 and raises TypeError from 3.11.
    to_cull = random.sample(list(self.digest_history), n_to_cull)
    self.digest_history.difference_update(to_cull)
797
797
def deserialize(self, msg_list, content=True, copy=True):
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/deserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_metadata,p_content,buffer1,buffer2,...]. With ``copy=False`` the
        first five entries are replaced in place by their bytes.
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether msg_list contains bytes (True) or the non-copying Message
        objects in each place (False).

    Returns
    -------
    msg : dict
        The result of ``adapt`` applied to the nested message dict with
        top-level keys [header, msg_id, msg_type, parent_header, metadata,
        content, buffers]. The buffers are memoryviews.

    Raises
    ------
    TypeError
        If msg_list has fewer than five parts.
    ValueError
        If the session has an ``auth`` object and the message is unsigned,
        its signature was already seen, or the signature does not match.
    """
    minlen = 5
    # Validate the length before touching any element, so a short message
    # is reported as malformed instead of surfacing as an IndexError.
    if len(msg_list) < minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    message = {}
    if not copy:
        # pyzmq didn't copy the first parts of the message, so we'll do it
        for i in range(minlen):
            msg_list[i] = msg_list[i].bytes
    if self.auth is not None:
        signature = msg_list[0]
        if not signature:
            raise ValueError("Unsigned Message")
        if signature in self.digest_history:
            raise ValueError("Duplicate Signature: %r" % signature)
        check = self.sign(msg_list[1:5])
        if not compare_digest(signature, check):
            raise ValueError("Invalid Signature: %r" % signature)
        # Only remember signatures that verified: recording unverified
        # ones would let any peer fill the replay history and force
        # legitimate digests to be culled.
        self._add_digest(signature)
    header = self.unpack(msg_list[1])
    message['header'] = extract_dates(header)
    message['msg_id'] = header['msg_id']
    message['msg_type'] = header['msg_type']
    message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
    message['metadata'] = self.unpack(msg_list[3])
    if content:
        message['content'] = self.unpack(msg_list[4])
    else:
        message['content'] = msg_list[4]
    buffers = [memoryview(b) for b in msg_list[5:]]
    if buffers and buffers[0].shape is None:
        # force copy to workaround pyzmq #646
        buffers = [memoryview(b.bytes) for b in msg_list[5:]]
    message['buffers'] = buffers
    # adapt to the current version
    return adapt(message)
858
858
def unserialize(self, *args, **kwargs):
    """Deprecated alias for ``deserialize``.

    Emits a DeprecationWarning and forwards all positional and keyword
    arguments to ``self.deserialize``, returning its result.
    """
    warnings.warn(
        "Session.unserialize is deprecated. Use Session.deserialize.",
        DeprecationWarning,
        # attribute the warning to the caller, not to this wrapper
        stacklevel=2,
    )
    return self.deserialize(*args, **kwargs)
865
865
866
866
def test_msg2obj():
    """Check attribute, item and dict round-trip access on ``Message``."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts must be reachable through chained attribute access
    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    # ...and through item access
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # converting back to a dict preserves the values
    round_trip = dict(wrapped)
    assert source['x'] == round_trip['x']
    assert source['y']['z'] == round_trip['y']['z']
882
882
General Comments 0
You need to be logged in to leave comments. Login now