##// END OF EJS Templates
Don't use io.rprint in jupyter_client.session...
Min RK -
Show More
@@ -1,889 +1,891 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
try:
    # Python 2: prefer the C-accelerated pickle implementation.
    import cPickle
    pickle = cPickle
except ImportError:
    # cPickle does not exist (e.g. Python 3): fall back to the stdlib module.
    # Only ImportError is caught so that KeyboardInterrupt/SystemExit and
    # genuine errors raised while importing are not silently swallowed.
    cPickle = None
    import pickle
30
30
# Python 3 exposes pickle.DEFAULT_PROTOCOL; where it is missing, use the
# highest protocol the interpreter supports instead.
PICKLE_PROTOCOL = getattr(pickle, 'DEFAULT_PROTOCOL', pickle.HIGHEST_PROTOCOL)
36
36
try:
    # compare_digest is used to limit the surface of timing attacks.
    from hmac import compare_digest
except ImportError:
    # Python < 2.7.7 has no hmac.compare_digest. When digests don't match
    # no feedback is provided to the sender, which limits the attack surface.
    def compare_digest(a, b):
        """Fallback digest comparison: plain equality of *a* and *b*."""
        return a == b
44
44
45 import zmq
45 import zmq
46 from zmq.utils import jsonapi
46 from zmq.utils import jsonapi
47 from zmq.eventloop.ioloop import IOLoop
47 from zmq.eventloop.ioloop import IOLoop
48 from zmq.eventloop.zmqstream import ZMQStream
48 from zmq.eventloop.zmqstream import ZMQStream
49
49
50 from IPython.core.release import kernel_protocol_version
50 from IPython.core.release import kernel_protocol_version
51 from IPython.config.configurable import Configurable, LoggingConfigurable
51 from IPython.config.configurable import Configurable, LoggingConfigurable
52 from IPython.utils import io
53 from IPython.utils.importstring import import_item
52 from IPython.utils.importstring import import_item
54 from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
53 from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
55 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
54 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
56 iteritems)
55 iteritems)
57 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
56 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
58 DottedObjectName, CUnicode, Dict, Integer,
57 DottedObjectName, CUnicode, Dict, Integer,
59 TraitError,
58 TraitError,
60 )
59 )
61 from jupyter_client.adapter import adapt
60 from jupyter_client.adapter import adapt
61 from traitlets.log import get_logger
62
62
63
63 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
64 # utility functions
65 # utility functions
65 #-----------------------------------------------------------------------------
66 #-----------------------------------------------------------------------------
66
67
def squash_unicode(obj):
    """Coerce unicode back to bytestrings.

    Dicts and lists are walked recursively and modified in place: values are
    squashed, and unicode dict keys are re-inserted under their squashed
    (UTF-8 encoded) form. A unicode string is returned UTF-8 encoded. Any
    other object is returned unchanged.

    Returns the (possibly replaced) object.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: entries are popped and
        # re-inserted below, and mutating a dict while iterating its live
        # key view may raise RuntimeError or revisit entries on Python 3.
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
80
81
81 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
82 # globals and defaults
83 # globals and defaults
83 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
84
85
# default values for the thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024


def json_packer(obj):
    """Serialize *obj* with zmq's jsonapi.

    date_default is the fallback for objects JSON can't encode (it is what
    ISO8601-ifies datetime objects), unicode is allowed through unescaped,
    and NaN is disallowed because it's not actually valid JSON.
    """
    return jsonapi.dumps(
        obj,
        default=date_default,
        ensure_ascii=False,
        allow_nan=False,
    )


def json_unpacker(s):
    """Deserialize *s* with zmq's jsonapi."""
    return jsonapi.loads(s)


def pickle_packer(o):
    """Pickle *o* with PICKLE_PROTOCOL after running it through squash_dates."""
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


# NOTE(review): unpickling can execute arbitrary code; only use the pickle
# unpacker with peers that are trusted.
pickle_unpacker = pickle.loads

default_packer = json_packer
default_unpacker = json_unpacker

DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
106
107
107 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
108 # Mixin tools for apps that use Sessions
109 # Mixin tools for apps that use Sessions
109 #-----------------------------------------------------------------------------
110 #-----------------------------------------------------------------------------
110
111
# Short command-line alias -> configurable trait it sets.
session_aliases = {
    'ident': 'Session.session',
    'user': 'Session.username',
    'keyfile': 'Session.keyfile',
}

# Flag name -> (config to apply, help text).
# The key for 'secure' is a fresh UUID generated once, when this module is imported.
session_flags = dict([
    ('secure', (
        {'Session': {'key': str_to_bytes(str(uuid.uuid4())), 'keyfile': ''}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    )),
    ('no-secure', (
        {'Session': {'key': b'', 'keyfile': ''}},
        """Don't authenticate messages.""",
    )),
])
126
127
def default_secure(cfg):
    """Deprecated: make a config environment secure by default.

    Emits a DeprecationWarning. When *cfg* already carries Session.key or
    Session.keyfile it is left untouched; otherwise Session.key is set to a
    new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning)
    has_session = 'Session' in cfg
    if has_session and ('key' in cfg.Session or 'keyfile' in cfg.Session):
        # an explicit key/keyfile wins; nothing to generate
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
139
140
140
141
141 #-----------------------------------------------------------------------------
142 #-----------------------------------------------------------------------------
142 # Classes
143 # Classes
143 #-----------------------------------------------------------------------------
144 #-----------------------------------------------------------------------------
144
145
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that carry a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        # swap in the logger registered under the new name
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    session = Instance('jupyter_client.session.Session', allow_none=True)

    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        """Initialize the configurable, building a Session from the same
        keyword arguments when none was supplied.
        """
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
172
173
173
174
class Message(object):
    """Attribute-access wrapper around a message dict.

    Keys of the source dict become instance attributes, and nested dicts are
    wrapped in Message objects as well. Calling dict(msg_obj) turns a Message
    back into a dict.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in iteritems(dict(msg_dict)):
            attrs[key] = Message(value) if isinstance(value, dict) else value

    # Iterating yields (key, value) pairs, which is what lets dict(msg_obj)
    # work out of the box.
    def __iter__(self):
        return iter(iteritems(self.__dict__))

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
202
203
203
204
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    The header holds the four arguments plus `date` (the current local time
    from datetime.now()) and `version` (kernel_protocol_version).
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
208
209
def extract_header(msg_or_header):
    """Given a message or header, return the header.

    A falsy argument yields an empty dict. If the argument has a 'header'
    entry, that entry is the header; otherwise the argument itself must have
    a 'msg_id' entry and is treated as the header. Non-dict headers are
    converted with dict().

    Raises KeyError when neither 'header' nor 'msg_id' is present.
    """
    if not msg_or_header:
        return {}
    try:
        # Is this the entire message?
        header = msg_or_header['header']
    except KeyError:
        # Otherwise it must be a bare header; this lookup is only a probe and
        # its KeyError propagates when 'msg_id' is missing too.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
227
228
228 class Session(Configurable):
229 class Session(Configurable):
229 """Object for handling serialization and sending of messages.
230 """Object for handling serialization and sending of messages.
230
231
231 The Session object handles building messages and sending them
232 The Session object handles building messages and sending them
232 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
233 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
233 other over the network via Session objects, and only need to work with the
234 other over the network via Session objects, and only need to work with the
234 dict-based IPython message spec. The Session will handle
235 dict-based IPython message spec. The Session will handle
235 serialization/deserialization, security, and metadata.
236 serialization/deserialization, security, and metadata.
236
237
237 Sessions support configurable serialization via packer/unpacker traits,
238 Sessions support configurable serialization via packer/unpacker traits,
238 and signing with HMAC digests via the key/keyfile traits.
239 and signing with HMAC digests via the key/keyfile traits.
239
240
240 Parameters
241 Parameters
241 ----------
242 ----------
242
243
243 debug : bool
244 debug : bool
244 whether to trigger extra debugging statements
245 whether to trigger extra debugging statements
245 packer/unpacker : str : 'json', 'pickle' or import_string
246 packer/unpacker : str : 'json', 'pickle' or import_string
246 importstrings for methods to serialize message parts. If just
247 importstrings for methods to serialize message parts. If just
247 'json' or 'pickle', predefined JSON and pickle packers will be used.
248 'json' or 'pickle', predefined JSON and pickle packers will be used.
248 Otherwise, the entire importstring must be used.
249 Otherwise, the entire importstring must be used.
249
250
250 The functions must accept at least valid JSON input, and output *bytes*.
251 The functions must accept at least valid JSON input, and output *bytes*.
251
252
252 For example, to use msgpack:
253 For example, to use msgpack:
253 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
254 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
254 pack/unpack : callables
255 pack/unpack : callables
255 You can also set the pack/unpack callables for serialization directly.
256 You can also set the pack/unpack callables for serialization directly.
256 session : bytes
257 session : bytes
257 the ID of this Session object. The default is to generate a new UUID.
258 the ID of this Session object. The default is to generate a new UUID.
258 username : unicode
259 username : unicode
259 username added to message headers. The default is to ask the OS.
260 username added to message headers. The default is to ask the OS.
260 key : bytes
261 key : bytes
261 The key used to initialize an HMAC signature. If unset, messages
262 The key used to initialize an HMAC signature. If unset, messages
262 will not be signed or checked.
263 will not be signed or checked.
263 keyfile : filepath
264 keyfile : filepath
264 The file containing a key. If this is set, `key` will be initialized
265 The file containing a key. If this is set, `key` will be initialized
265 to the contents of the file.
266 to the contents of the file.
266
267
267 """
268 """
268
269
269 debug=Bool(False, config=True, help="""Debug output in the Session""")
270 debug=Bool(False, config=True, help="""Debug output in the Session""")
270
271
271 packer = DottedObjectName('json',config=True,
272 packer = DottedObjectName('json',config=True,
272 help="""The name of the packer for serializing messages.
273 help="""The name of the packer for serializing messages.
273 Should be one of 'json', 'pickle', or an import name
274 Should be one of 'json', 'pickle', or an import name
274 for a custom callable serializer.""")
275 for a custom callable serializer.""")
def _packer_changed(self, name, old, new):
    """Change handler for `packer` (traitlets `_<name>_changed` naming).

    'json' and 'pickle' (case-insensitive) install the predefined pack/unpack
    pair and mirror the name onto `unpacker`; any other value is imported as
    a dotted name and installed as `pack` only.
    """
    choice = new.lower()
    if choice == 'json':
        self.pack, self.unpack = json_packer, json_unpacker
        self.unpacker = new
    elif choice == 'pickle':
        self.pack, self.unpack = pickle_packer, pickle_unpacker
        self.unpacker = new
    else:
        self.pack = import_item(str(new))
286
287
287 unpacker = DottedObjectName('json', config=True,
288 unpacker = DottedObjectName('json', config=True,
288 help="""The name of the unpacker for unserializing messages.
289 help="""The name of the unpacker for unserializing messages.
289 Only used with custom functions for `packer`.""")
290 Only used with custom functions for `packer`.""")
def _unpacker_changed(self, name, old, new):
    """Change handler for `unpacker` (traitlets `_<name>_changed` naming).

    'json' and 'pickle' (case-insensitive) install the predefined pack/unpack
    pair and mirror the name onto `packer`; any other value is imported as a
    dotted name and installed as `unpack` only.
    """
    choice = new.lower()
    if choice == 'json':
        self.pack, self.unpack = json_packer, json_unpacker
        self.packer = new
    elif choice == 'pickle':
        self.pack, self.unpack = pickle_packer, pickle_unpacker
        self.packer = new
    else:
        self.unpack = import_item(str(new))
301
302
302 session = CUnicode(u'', config=True,
303 session = CUnicode(u'', config=True,
303 help="""The UUID identifying this session.""")
304 help="""The UUID identifying this session.""")
def _session_default(self):
    """Generate a new UUID as the session id and return it as unicode.

    Also stores the ASCII-encoded form on `bsession`.
    """
    session_id = unicode_type(uuid.uuid4())
    self.bsession = session_id.encode('ascii')
    return session_id
308
309
def _session_changed(self, name, old, new):
    """Keep `bsession` (the session id as bytes) in step with `session`."""
    encoded = self.session.encode('ascii')
    self.bsession = encoded
311
312
312 # bsession is the session as bytes
313 # bsession is the session as bytes
313 bsession = CBytes(b'')
314 bsession = CBytes(b'')
314
315
315 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
316 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
316 help="""Username for the Session. Default is your system username.""",
317 help="""Username for the Session. Default is your system username.""",
317 config=True)
318 config=True)
318
319
319 metadata = Dict({}, config=True,
320 metadata = Dict({}, config=True,
320 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
321 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
321
322
322 # if 0, no adapting to do.
323 # if 0, no adapting to do.
323 adapt_version = Integer(0)
324 adapt_version = Integer(0)
324
325
325 # message signature related traits:
326 # message signature related traits:
326
327
327 key = CBytes(config=True,
328 key = CBytes(config=True,
328 help="""execution key, for signing messages.""")
329 help="""execution key, for signing messages.""")
def _key_default(self):
    """Default signing key: a new random UUID, converted with str_to_bytes."""
    fresh_uuid = str(uuid.uuid4())
    return str_to_bytes(fresh_uuid)
331
332
def _key_changed(self):
    """Rebuild the HMAC object whenever the key changes."""
    # `auth` is derived from the key, so it has to be recreated
    self._new_auth()
334
335
335 signature_scheme = Unicode('hmac-sha256', config=True,
336 signature_scheme = Unicode('hmac-sha256', config=True,
336 help="""The digest scheme used to construct the message signatures.
337 help="""The digest scheme used to construct the message signatures.
337 Must have the form 'hmac-HASH'.""")
338 Must have the form 'hmac-HASH'.""")
def _signature_scheme_changed(self, name, old, new):
    """Change handler for `signature_scheme`.

    Resolves the HASH part of 'hmac-HASH' to an attribute of hashlib, stores
    it on `digest_mod`, and rebuilds the HMAC object.

    Raises TraitError if the value does not start with 'hmac-' or hashlib has
    no attribute named HASH.
    """
    if not new.startswith('hmac-'):
        raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
    # everything after the first dash names the hash function
    _, _, hash_name = new.partition('-')
    try:
        self.digest_mod = getattr(hashlib, hash_name)
    except AttributeError:
        raise TraitError("hashlib has no such attribute: %s" % hash_name)
    self._new_auth()
347
348
348 digest_mod = Any()
349 digest_mod = Any()
def _digest_mod_default(self):
    """Default digest constructor: SHA-256, matching the 'hmac-sha256' scheme default."""
    default_digest = hashlib.sha256
    return default_digest
351
352
352 auth = Instance(hmac.HMAC, allow_none=True)
353 auth = Instance(hmac.HMAC, allow_none=True)
353
354
def _new_auth(self):
    """(Re)create `auth` from the current key and digest.

    With a non-empty key, `auth` becomes an hmac.HMAC built with `digest_mod`;
    with an empty key it is reset to None.
    """
    self.auth = (
        hmac.HMAC(self.key, digestmod=self.digest_mod) if self.key else None
    )
359
360
360 digest_history = Set()
361 digest_history = Set()
361 digest_history_size = Integer(2**16, config=True,
362 digest_history_size = Integer(2**16, config=True,
362 help="""The maximum number of digests to remember.
363 help="""The maximum number of digests to remember.
363
364
364 The digest history will be culled when it exceeds this value.
365 The digest history will be culled when it exceeds this value.
365 """
366 """
366 )
367 )
367
368
368 keyfile = Unicode('', config=True,
369 keyfile = Unicode('', config=True,
369 help="""path to file containing execution key.""")
370 help="""path to file containing execution key.""")
def _keyfile_changed(self, name, old, new):
    """Change handler for `keyfile`: load the key from the file at *new*.

    The file is read in binary mode and surrounding whitespace is stripped
    before the contents are assigned to `key`.
    """
    with open(new, 'rb') as key_source:
        raw_key = key_source.read()
    self.key = raw_key.strip()
373
374
374 # for protecting against sends from forks
375 # for protecting against sends from forks
375 pid = Integer()
376 pid = Integer()
376
377
377 # serialization traits:
378 # serialization traits:
378
379
379 pack = Any(default_packer) # the actual packer function
380 pack = Any(default_packer) # the actual packer function
def _pack_changed(self, name, old, new):
    """Reject a non-callable `pack` value with TypeError."""
    if callable(new):
        return
    raise TypeError("packer must be callable, not %s"%type(new))
383
384
384 unpack = Any(default_unpacker) # the actual packer function
385 unpack = Any(default_unpacker) # the actual packer function
def _unpack_changed(self, name, old, new):
    """Reject a non-callable `unpack` value with TypeError."""
    if callable(new):
        return
    raise TypeError("unpacker must be callable, not %s"%type(new))
389
390
390 # thresholds:
391 # thresholds:
391 copy_threshold = Integer(2**16, config=True,
392 copy_threshold = Integer(2**16, config=True,
392 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
393 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
393 buffer_threshold = Integer(MAX_BYTES, config=True,
394 buffer_threshold = Integer(MAX_BYTES, config=True,
394 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
395 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
395 item_threshold = Integer(MAX_ITEMS, config=True,
396 item_threshold = Integer(MAX_ITEMS, config=True,
396 help="""The maximum number of items for a container to be introspected for custom serialization.
397 help="""The maximum number of items for a container to be introspected for custom serialization.
397 Containers larger than this are pickled outright.
398 Containers larger than this are pickled outright.
398 """
399 """
399 )
400 )
400
401
401
402
def __init__(self, **kwargs):
    """Create a Session object.

    Parameters
    ----------

    debug : bool
        Whether to trigger extra debugging statements.
    packer/unpacker : str : 'json', 'pickle' or import_string
        Import strings for the functions that serialize message parts.
        'json' and 'pickle' select the predefined JSON and pickle packers;
        anything else must be a complete import string.

        The functions must accept at least valid JSON input, and output
        *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        The serialization callables can also be set directly.
    session : unicode (must be ascii)
        The ID of this Session object. The default is to generate a new
        UUID.
    bsession : bytes
        The session as bytes.
    username : unicode
        Username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    signature_scheme : str
        The message digest scheme. Currently must be of the form
        'hmac-HASH', where 'HASH' is a hashing function available in
        Python's hashlib. The default is 'hmac-sha256'.
        This is ignored if 'key' is empty.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be
        initialized to the contents of the file.
    """
    super(Session, self).__init__(**kwargs)
    # validate the configured pack/unpack pair before anything uses it
    self._check_packers()
    # packed form of an empty dict
    self.none = self.pack({})
    # touch the trait so _session_default() runs if necessary and bsession
    # is defined
    self.session
    # remembered for protecting against sends from forks
    self.pid = os.getpid()
    self._new_auth()
449
450
@property
def msg_id(self):
    """A new UUID string on every access."""
    fresh = uuid.uuid4()
    return str(fresh)
454
455
def _check_packers(self):
    """Validate the configured pack/unpack pair and check datetime support.

    Three checks are run against `self.pack` / `self.unpack`:

    1. a simple message must serialize, and the result must be bytes;
    2. unpacking the packed message must give back the original message;
    3. a datetime must not come back from a round trip as a datetime. If it
       does, or if the round trip fails, `self.pack` is wrapped so dates go
       through squash_dates first, and `self.unpack` is wrapped as well.

    Raises ValueError if check 1 or 2 fails.
    """
    pack = self.pack
    unpack = self.unpack

    def json_hint():
        # extra diagnostics, only meaningful for the built-in json packer
        if self.packer == 'json':
            return "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
        return ""

    # check simple serialization
    simple = dict(a=[1, 'hi'])
    try:
        packed = pack(simple)
    except Exception as e:
        template = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
        raise ValueError(
            template.format(packer=self.packer, e=e, jsonmsg=json_hint())
        )

    # ensure packed message is bytes
    if not isinstance(packed, bytes):
        raise ValueError("message packed to %r, but bytes are required"%type(packed))

    # check that unpack is pack's inverse
    try:
        unpacked = unpack(packed)
        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would let a broken unpacker through unnoticed.
        if unpacked != simple:
            raise ValueError(
                "round trip gave %r, expected %r" % (unpacked, simple)
            )
    except Exception as e:
        template = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
        raise ValueError(
            template.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=json_hint())
        )

    # check datetime support
    dated = dict(t=datetime.now())
    try:
        unpacked = unpack(pack(dated))
        if isinstance(unpacked['t'], datetime):
            raise ValueError("Shouldn't deserialize to datetime")
    except Exception:
        # Deliberate best-effort: any failure here just means dates must be
        # squashed before they reach the packer.
        self.pack = lambda o: pack(squash_dates(o))
        self.unpack = lambda s: unpack(s)
501
502
502 def msg_header(self, msg_type):
503 def msg_header(self, msg_type):
503 return msg_header(self.msg_id, msg_type, self.username, self.session)
504 return msg_header(self.msg_id, msg_type, self.username, self.session)
504
505
505 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
506 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
506 """Return the nested message dict.
507 """Return the nested message dict.
507
508
508 This format is different from what is sent over the wire. The
509 This format is different from what is sent over the wire. The
509 serialize/deserialize methods convert this nested message dict to the wire
510 serialize/deserialize methods convert this nested message dict to the wire
510 format, which is a list of message parts.
511 format, which is a list of message parts.
511 """
512 """
512 msg = {}
513 msg = {}
513 header = self.msg_header(msg_type) if header is None else header
514 header = self.msg_header(msg_type) if header is None else header
514 msg['header'] = header
515 msg['header'] = header
515 msg['msg_id'] = header['msg_id']
516 msg['msg_id'] = header['msg_id']
516 msg['msg_type'] = header['msg_type']
517 msg['msg_type'] = header['msg_type']
517 msg['parent_header'] = {} if parent is None else extract_header(parent)
518 msg['parent_header'] = {} if parent is None else extract_header(parent)
518 msg['content'] = {} if content is None else content
519 msg['content'] = {} if content is None else content
519 msg['metadata'] = self.metadata.copy()
520 msg['metadata'] = self.metadata.copy()
520 if metadata is not None:
521 if metadata is not None:
521 msg['metadata'].update(metadata)
522 msg['metadata'].update(metadata)
522 return msg
523 return msg
523
524
524 def sign(self, msg_list):
525 def sign(self, msg_list):
525 """Sign a message with HMAC digest. If no auth, return b''.
526 """Sign a message with HMAC digest. If no auth, return b''.
526
527
527 Parameters
528 Parameters
528 ----------
529 ----------
529 msg_list : list
530 msg_list : list
530 The [p_header,p_parent,p_metadata,p_content] part of the message list.
531 The [p_header,p_parent,p_metadata,p_content] part of the message list.
531 """
532 """
532 if self.auth is None:
533 if self.auth is None:
533 return b''
534 return b''
534 h = self.auth.copy()
535 h = self.auth.copy()
535 for m in msg_list:
536 for m in msg_list:
536 h.update(m)
537 h.update(m)
537 return str_to_bytes(h.hexdigest())
538 return str_to_bytes(h.hexdigest())
538
539
539 def serialize(self, msg, ident=None):
540 def serialize(self, msg, ident=None):
540 """Serialize the message components to bytes.
541 """Serialize the message components to bytes.
541
542
542 This is roughly the inverse of deserialize. The serialize/deserialize
543 This is roughly the inverse of deserialize. The serialize/deserialize
543 methods work with full message lists, whereas pack/unpack work with
544 methods work with full message lists, whereas pack/unpack work with
544 the individual message parts in the message list.
545 the individual message parts in the message list.
545
546
546 Parameters
547 Parameters
547 ----------
548 ----------
548 msg : dict or Message
549 msg : dict or Message
549 The next message dict as returned by the self.msg method.
550 The next message dict as returned by the self.msg method.
550
551
551 Returns
552 Returns
552 -------
553 -------
553 msg_list : list
554 msg_list : list
554 The list of bytes objects to be sent with the format::
555 The list of bytes objects to be sent with the format::
555
556
556 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
557 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
557 p_metadata, p_content, buffer1, buffer2, ...]
558 p_metadata, p_content, buffer1, buffer2, ...]
558
559
559 In this list, the ``p_*`` entities are the packed or serialized
560 In this list, the ``p_*`` entities are the packed or serialized
560 versions, so if JSON is used, these are utf8 encoded JSON strings.
561 versions, so if JSON is used, these are utf8 encoded JSON strings.
561 """
562 """
562 content = msg.get('content', {})
563 content = msg.get('content', {})
563 if content is None:
564 if content is None:
564 content = self.none
565 content = self.none
565 elif isinstance(content, dict):
566 elif isinstance(content, dict):
566 content = self.pack(content)
567 content = self.pack(content)
567 elif isinstance(content, bytes):
568 elif isinstance(content, bytes):
568 # content is already packed, as in a relayed message
569 # content is already packed, as in a relayed message
569 pass
570 pass
570 elif isinstance(content, unicode_type):
571 elif isinstance(content, unicode_type):
571 # should be bytes, but JSON often spits out unicode
572 # should be bytes, but JSON often spits out unicode
572 content = content.encode('utf8')
573 content = content.encode('utf8')
573 else:
574 else:
574 raise TypeError("Content incorrect type: %s"%type(content))
575 raise TypeError("Content incorrect type: %s"%type(content))
575
576
576 real_message = [self.pack(msg['header']),
577 real_message = [self.pack(msg['header']),
577 self.pack(msg['parent_header']),
578 self.pack(msg['parent_header']),
578 self.pack(msg['metadata']),
579 self.pack(msg['metadata']),
579 content,
580 content,
580 ]
581 ]
581
582
582 to_send = []
583 to_send = []
583
584
584 if isinstance(ident, list):
585 if isinstance(ident, list):
585 # accept list of idents
586 # accept list of idents
586 to_send.extend(ident)
587 to_send.extend(ident)
587 elif ident is not None:
588 elif ident is not None:
588 to_send.append(ident)
589 to_send.append(ident)
589 to_send.append(DELIM)
590 to_send.append(DELIM)
590
591
591 signature = self.sign(real_message)
592 signature = self.sign(real_message)
592 to_send.append(signature)
593 to_send.append(signature)
593
594
594 to_send.extend(real_message)
595 to_send.extend(real_message)
595
596
596 return to_send
597 return to_send
597
598
598 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
599 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
599 buffers=None, track=False, header=None, metadata=None):
600 buffers=None, track=False, header=None, metadata=None):
600 """Build and send a message via stream or socket.
601 """Build and send a message via stream or socket.
601
602
602 The message format used by this function internally is as follows:
603 The message format used by this function internally is as follows:
603
604
604 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
605 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
605 buffer1,buffer2,...]
606 buffer1,buffer2,...]
606
607
607 The serialize/deserialize methods convert the nested message dict into this
608 The serialize/deserialize methods convert the nested message dict into this
608 format.
609 format.
609
610
610 Parameters
611 Parameters
611 ----------
612 ----------
612
613
613 stream : zmq.Socket or ZMQStream
614 stream : zmq.Socket or ZMQStream
614 The socket-like object used to send the data.
615 The socket-like object used to send the data.
615 msg_or_type : str or Message/dict
616 msg_or_type : str or Message/dict
616 Normally, msg_or_type will be a msg_type unless a message is being
617 Normally, msg_or_type will be a msg_type unless a message is being
617 sent more than once. If a header is supplied, this can be set to
618 sent more than once. If a header is supplied, this can be set to
618 None and the msg_type will be pulled from the header.
619 None and the msg_type will be pulled from the header.
619
620
620 content : dict or None
621 content : dict or None
621 The content of the message (ignored if msg_or_type is a message).
622 The content of the message (ignored if msg_or_type is a message).
622 header : dict or None
623 header : dict or None
623 The header dict for the message (ignored if msg_or_type is a message).
624 The header dict for the message (ignored if msg_or_type is a message).
624 parent : Message or dict or None
625 parent : Message or dict or None
625 The parent or parent header describing the parent of this message
626 The parent or parent header describing the parent of this message
626 (ignored if msg_or_type is a message).
627 (ignored if msg_or_type is a message).
627 ident : bytes or list of bytes
628 ident : bytes or list of bytes
628 The zmq.IDENTITY routing path.
629 The zmq.IDENTITY routing path.
629 metadata : dict or None
630 metadata : dict or None
630 The metadata describing the message
631 The metadata describing the message
631 buffers : list or None
632 buffers : list or None
632 The already-serialized buffers to be appended to the message.
633 The already-serialized buffers to be appended to the message.
633 track : bool
634 track : bool
634 Whether to track. Only for use with Sockets, because ZMQStream
635 Whether to track. Only for use with Sockets, because ZMQStream
635 objects cannot track messages.
636 objects cannot track messages.
636
637
637
638
638 Returns
639 Returns
639 -------
640 -------
640 msg : dict
641 msg : dict
641 The constructed message.
642 The constructed message.
642 """
643 """
643 if not isinstance(stream, zmq.Socket):
644 if not isinstance(stream, zmq.Socket):
644 # ZMQStreams and dummy sockets do not support tracking.
645 # ZMQStreams and dummy sockets do not support tracking.
645 track = False
646 track = False
646
647
647 if isinstance(msg_or_type, (Message, dict)):
648 if isinstance(msg_or_type, (Message, dict)):
648 # We got a Message or message dict, not a msg_type so don't
649 # We got a Message or message dict, not a msg_type so don't
649 # build a new Message.
650 # build a new Message.
650 msg = msg_or_type
651 msg = msg_or_type
651 buffers = buffers or msg.get('buffers', [])
652 buffers = buffers or msg.get('buffers', [])
652 else:
653 else:
653 msg = self.msg(msg_or_type, content=content, parent=parent,
654 msg = self.msg(msg_or_type, content=content, parent=parent,
654 header=header, metadata=metadata)
655 header=header, metadata=metadata)
655 if not os.getpid() == self.pid:
656 if not os.getpid() == self.pid:
656 io.rprint("WARNING: attempted to send message from fork")
657 get_logger().warn("WARNING: attempted to send message from fork\n%s",
657 io.rprint(msg)
658 msg
659 )
658 return
660 return
659 buffers = [] if buffers is None else buffers
661 buffers = [] if buffers is None else buffers
660 if self.adapt_version:
662 if self.adapt_version:
661 msg = adapt(msg, self.adapt_version)
663 msg = adapt(msg, self.adapt_version)
662 to_send = self.serialize(msg, ident)
664 to_send = self.serialize(msg, ident)
663 to_send.extend(buffers)
665 to_send.extend(buffers)
664 longest = max([ len(s) for s in to_send ])
666 longest = max([ len(s) for s in to_send ])
665 copy = (longest < self.copy_threshold)
667 copy = (longest < self.copy_threshold)
666
668
667 if buffers and track and not copy:
669 if buffers and track and not copy:
668 # only really track when we are doing zero-copy buffers
670 # only really track when we are doing zero-copy buffers
669 tracker = stream.send_multipart(to_send, copy=False, track=True)
671 tracker = stream.send_multipart(to_send, copy=False, track=True)
670 else:
672 else:
671 # use dummy tracker, which will be done immediately
673 # use dummy tracker, which will be done immediately
672 tracker = DONE
674 tracker = DONE
673 stream.send_multipart(to_send, copy=copy)
675 stream.send_multipart(to_send, copy=copy)
674
676
675 if self.debug:
677 if self.debug:
676 pprint.pprint(msg)
678 pprint.pprint(msg)
677 pprint.pprint(to_send)
679 pprint.pprint(to_send)
678 pprint.pprint(buffers)
680 pprint.pprint(buffers)
679
681
680 msg['tracker'] = tracker
682 msg['tracker'] = tracker
681
683
682 return msg
684 return msg
683
685
684 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
686 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
685 """Send a raw message via ident path.
687 """Send a raw message via ident path.
686
688
687 This method is used to send an already serialized message.
689 This method is used to send an already serialized message.
688
690
689 Parameters
691 Parameters
690 ----------
692 ----------
691 stream : ZMQStream or Socket
693 stream : ZMQStream or Socket
692 The ZMQ stream or socket to use for sending the message.
694 The ZMQ stream or socket to use for sending the message.
693 msg_list : list
695 msg_list : list
694 The serialized list of messages to send. This only includes the
696 The serialized list of messages to send. This only includes the
695 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
697 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
696 the message.
698 the message.
697 ident : ident or list
699 ident : ident or list
698 A single ident or a list of idents to use in sending.
700 A single ident or a list of idents to use in sending.
699 """
701 """
700 to_send = []
702 to_send = []
701 if isinstance(ident, bytes):
703 if isinstance(ident, bytes):
702 ident = [ident]
704 ident = [ident]
703 if ident is not None:
705 if ident is not None:
704 to_send.extend(ident)
706 to_send.extend(ident)
705
707
706 to_send.append(DELIM)
708 to_send.append(DELIM)
707 to_send.append(self.sign(msg_list))
709 to_send.append(self.sign(msg_list))
708 to_send.extend(msg_list)
710 to_send.extend(msg_list)
709 stream.send_multipart(to_send, flags, copy=copy)
711 stream.send_multipart(to_send, flags, copy=copy)
710
712
711 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
713 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
712 """Receive and unpack a message.
714 """Receive and unpack a message.
713
715
714 Parameters
716 Parameters
715 ----------
717 ----------
716 socket : ZMQStream or Socket
718 socket : ZMQStream or Socket
717 The socket or stream to use in receiving.
719 The socket or stream to use in receiving.
718
720
719 Returns
721 Returns
720 -------
722 -------
721 [idents], msg
723 [idents], msg
722 [idents] is a list of idents and msg is a nested message dict of
724 [idents] is a list of idents and msg is a nested message dict of
723 same format as self.msg returns.
725 same format as self.msg returns.
724 """
726 """
725 if isinstance(socket, ZMQStream):
727 if isinstance(socket, ZMQStream):
726 socket = socket.socket
728 socket = socket.socket
727 try:
729 try:
728 msg_list = socket.recv_multipart(mode, copy=copy)
730 msg_list = socket.recv_multipart(mode, copy=copy)
729 except zmq.ZMQError as e:
731 except zmq.ZMQError as e:
730 if e.errno == zmq.EAGAIN:
732 if e.errno == zmq.EAGAIN:
731 # We can convert EAGAIN to None as we know in this case
733 # We can convert EAGAIN to None as we know in this case
732 # recv_multipart won't return None.
734 # recv_multipart won't return None.
733 return None,None
735 return None,None
734 else:
736 else:
735 raise
737 raise
736 # split multipart message into identity list and message dict
738 # split multipart message into identity list and message dict
737 # invalid large messages can cause very expensive string comparisons
739 # invalid large messages can cause very expensive string comparisons
738 idents, msg_list = self.feed_identities(msg_list, copy)
740 idents, msg_list = self.feed_identities(msg_list, copy)
739 try:
741 try:
740 return idents, self.deserialize(msg_list, content=content, copy=copy)
742 return idents, self.deserialize(msg_list, content=content, copy=copy)
741 except Exception as e:
743 except Exception as e:
742 # TODO: handle it
744 # TODO: handle it
743 raise e
745 raise e
744
746
745 def feed_identities(self, msg_list, copy=True):
747 def feed_identities(self, msg_list, copy=True):
746 """Split the identities from the rest of the message.
748 """Split the identities from the rest of the message.
747
749
748 Feed until DELIM is reached, then return the prefix as idents and
750 Feed until DELIM is reached, then return the prefix as idents and
749 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
751 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
750 but that would be silly.
752 but that would be silly.
751
753
752 Parameters
754 Parameters
753 ----------
755 ----------
754 msg_list : a list of Message or bytes objects
756 msg_list : a list of Message or bytes objects
755 The message to be split.
757 The message to be split.
756 copy : bool
758 copy : bool
757 flag determining whether the arguments are bytes or Messages
759 flag determining whether the arguments are bytes or Messages
758
760
759 Returns
761 Returns
760 -------
762 -------
761 (idents, msg_list) : two lists
763 (idents, msg_list) : two lists
762 idents will always be a list of bytes, each of which is a ZMQ
764 idents will always be a list of bytes, each of which is a ZMQ
763 identity. msg_list will be a list of bytes or zmq.Messages of the
765 identity. msg_list will be a list of bytes or zmq.Messages of the
764 form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and
766 form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] and
765 should be unpackable/unserializable via self.deserialize at this
767 should be unpackable/unserializable via self.deserialize at this
766 point.
768 point.
767 """
769 """
768 if copy:
770 if copy:
769 idx = msg_list.index(DELIM)
771 idx = msg_list.index(DELIM)
770 return msg_list[:idx], msg_list[idx+1:]
772 return msg_list[:idx], msg_list[idx+1:]
771 else:
773 else:
772 failed = True
774 failed = True
773 for idx,m in enumerate(msg_list):
775 for idx,m in enumerate(msg_list):
774 if m.bytes == DELIM:
776 if m.bytes == DELIM:
775 failed = False
777 failed = False
776 break
778 break
777 if failed:
779 if failed:
778 raise ValueError("DELIM not in msg_list")
780 raise ValueError("DELIM not in msg_list")
779 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
781 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
780 return [m.bytes for m in idents], msg_list
782 return [m.bytes for m in idents], msg_list
781
783
782 def _add_digest(self, signature):
784 def _add_digest(self, signature):
783 """add a digest to history to protect against replay attacks"""
785 """add a digest to history to protect against replay attacks"""
784 if self.digest_history_size == 0:
786 if self.digest_history_size == 0:
785 # no history, never add digests
787 # no history, never add digests
786 return
788 return
787
789
788 self.digest_history.add(signature)
790 self.digest_history.add(signature)
789 if len(self.digest_history) > self.digest_history_size:
791 if len(self.digest_history) > self.digest_history_size:
790 # threshold reached, cull 10%
792 # threshold reached, cull 10%
791 self._cull_digest_history()
793 self._cull_digest_history()
792
794
793 def _cull_digest_history(self):
795 def _cull_digest_history(self):
794 """cull the digest history
796 """cull the digest history
795
797
796 Removes a randomly selected 10% of the digest history
798 Removes a randomly selected 10% of the digest history
797 """
799 """
798 current = len(self.digest_history)
800 current = len(self.digest_history)
799 n_to_cull = max(int(current // 10), current - self.digest_history_size)
801 n_to_cull = max(int(current // 10), current - self.digest_history_size)
800 if n_to_cull >= current:
802 if n_to_cull >= current:
801 self.digest_history = set()
803 self.digest_history = set()
802 return
804 return
803 to_cull = random.sample(self.digest_history, n_to_cull)
805 to_cull = random.sample(self.digest_history, n_to_cull)
804 self.digest_history.difference_update(to_cull)
806 self.digest_history.difference_update(to_cull)
805
807
806 def deserialize(self, msg_list, content=True, copy=True):
808 def deserialize(self, msg_list, content=True, copy=True):
807 """Unserialize a msg_list to a nested message dict.
809 """Unserialize a msg_list to a nested message dict.
808
810
809 This is roughly the inverse of serialize. The serialize/deserialize
811 This is roughly the inverse of serialize. The serialize/deserialize
810 methods work with full message lists, whereas pack/unpack work with
812 methods work with full message lists, whereas pack/unpack work with
811 the individual message parts in the message list.
813 the individual message parts in the message list.
812
814
813 Parameters
815 Parameters
814 ----------
816 ----------
815 msg_list : list of bytes or Message objects
817 msg_list : list of bytes or Message objects
816 The list of message parts of the form [HMAC,p_header,p_parent,
818 The list of message parts of the form [HMAC,p_header,p_parent,
817 p_metadata,p_content,buffer1,buffer2,...].
819 p_metadata,p_content,buffer1,buffer2,...].
818 content : bool (True)
820 content : bool (True)
819 Whether to unpack the content dict (True), or leave it packed
821 Whether to unpack the content dict (True), or leave it packed
820 (False).
822 (False).
821 copy : bool (True)
823 copy : bool (True)
822 Whether msg_list contains bytes (True) or the non-copying Message
824 Whether msg_list contains bytes (True) or the non-copying Message
823 objects in each place (False).
825 objects in each place (False).
824
826
825 Returns
827 Returns
826 -------
828 -------
827 msg : dict
829 msg : dict
828 The nested message dict with top-level keys [header, parent_header,
830 The nested message dict with top-level keys [header, parent_header,
829 content, buffers]. The buffers are returned as memoryviews.
831 content, buffers]. The buffers are returned as memoryviews.
830 """
832 """
831 minlen = 5
833 minlen = 5
832 message = {}
834 message = {}
833 if not copy:
835 if not copy:
834 # pyzmq didn't copy the first parts of the message, so we'll do it
836 # pyzmq didn't copy the first parts of the message, so we'll do it
835 for i in range(minlen):
837 for i in range(minlen):
836 msg_list[i] = msg_list[i].bytes
838 msg_list[i] = msg_list[i].bytes
837 if self.auth is not None:
839 if self.auth is not None:
838 signature = msg_list[0]
840 signature = msg_list[0]
839 if not signature:
841 if not signature:
840 raise ValueError("Unsigned Message")
842 raise ValueError("Unsigned Message")
841 if signature in self.digest_history:
843 if signature in self.digest_history:
842 raise ValueError("Duplicate Signature: %r" % signature)
844 raise ValueError("Duplicate Signature: %r" % signature)
843 self._add_digest(signature)
845 self._add_digest(signature)
844 check = self.sign(msg_list[1:5])
846 check = self.sign(msg_list[1:5])
845 if not compare_digest(signature, check):
847 if not compare_digest(signature, check):
846 raise ValueError("Invalid Signature: %r" % signature)
848 raise ValueError("Invalid Signature: %r" % signature)
847 if not len(msg_list) >= minlen:
849 if not len(msg_list) >= minlen:
848 raise TypeError("malformed message, must have at least %i elements"%minlen)
850 raise TypeError("malformed message, must have at least %i elements"%minlen)
849 header = self.unpack(msg_list[1])
851 header = self.unpack(msg_list[1])
850 message['header'] = extract_dates(header)
852 message['header'] = extract_dates(header)
851 message['msg_id'] = header['msg_id']
853 message['msg_id'] = header['msg_id']
852 message['msg_type'] = header['msg_type']
854 message['msg_type'] = header['msg_type']
853 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
855 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
854 message['metadata'] = self.unpack(msg_list[3])
856 message['metadata'] = self.unpack(msg_list[3])
855 if content:
857 if content:
856 message['content'] = self.unpack(msg_list[4])
858 message['content'] = self.unpack(msg_list[4])
857 else:
859 else:
858 message['content'] = msg_list[4]
860 message['content'] = msg_list[4]
859 buffers = [memoryview(b) for b in msg_list[5:]]
861 buffers = [memoryview(b) for b in msg_list[5:]]
860 if buffers and buffers[0].shape is None:
862 if buffers and buffers[0].shape is None:
861 # force copy to workaround pyzmq #646
863 # force copy to workaround pyzmq #646
862 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
864 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
863 message['buffers'] = buffers
865 message['buffers'] = buffers
864 # adapt to the current version
866 # adapt to the current version
865 return adapt(message)
867 return adapt(message)
866
868
867 def unserialize(self, *args, **kwargs):
869 def unserialize(self, *args, **kwargs):
868 warnings.warn(
870 warnings.warn(
869 "Session.unserialize is deprecated. Use Session.deserialize.",
871 "Session.unserialize is deprecated. Use Session.deserialize.",
870 DeprecationWarning,
872 DeprecationWarning,
871 )
873 )
872 return self.deserialize(*args, **kwargs)
874 return self.deserialize(*args, **kwargs)
873
875
874
876
875 def test_msg2obj():
877 def test_msg2obj():
876 am = dict(x=1)
878 am = dict(x=1)
877 ao = Message(am)
879 ao = Message(am)
878 assert ao.x == am['x']
880 assert ao.x == am['x']
879
881
880 am['y'] = dict(z=1)
882 am['y'] = dict(z=1)
881 ao = Message(am)
883 ao = Message(am)
882 assert ao.y.z == am['y']['z']
884 assert ao.y.z == am['y']['z']
883
885
884 k1, k2 = 'y', 'z'
886 k1, k2 = 'y', 'z'
885 assert ao[k1][k2] == am[k1][k2]
887 assert ao[k1][k2] == am[k1][k2]
886
888
887 am2 = dict(ao)
889 am2 = dict(ao)
888 assert am['x'] == am2['x']
890 assert am['x'] == am2['x']
889 assert am['y']['z'] == am2['y']['z']
891 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now