##// END OF EJS Templates
build new Session auth when either key or scheme changes
MinRK -
Show More
@@ -1,852 +1,856 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 from datetime import datetime
21 from datetime import datetime
22
22
23 try:
23 try:
24 import cPickle
24 import cPickle
25 pickle = cPickle
25 pickle = cPickle
26 except:
26 except:
27 cPickle = None
27 cPickle = None
28 import pickle
28 import pickle
29
29
30 import zmq
30 import zmq
31 from zmq.utils import jsonapi
31 from zmq.utils import jsonapi
32 from zmq.eventloop.ioloop import IOLoop
32 from zmq.eventloop.ioloop import IOLoop
33 from zmq.eventloop.zmqstream import ZMQStream
33 from zmq.eventloop.zmqstream import ZMQStream
34
34
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 from IPython.config.configurable import Configurable, LoggingConfigurable
36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 from IPython.utils import io
37 from IPython.utils import io
38 from IPython.utils.importstring import import_item
38 from IPython.utils.importstring import import_item
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 iteritems)
41 iteritems)
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 DottedObjectName, CUnicode, Dict, Integer,
43 DottedObjectName, CUnicode, Dict, Integer,
44 TraitError,
44 TraitError,
45 )
45 )
46 from IPython.kernel.adapter import adapt
46 from IPython.kernel.adapter import adapt
47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48
48
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50 # utility functions
50 # utility functions
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52
52
53 def squash_unicode(obj):
53 def squash_unicode(obj):
54 """coerce unicode back to bytestrings."""
54 """coerce unicode back to bytestrings."""
55 if isinstance(obj,dict):
55 if isinstance(obj,dict):
56 for key in obj.keys():
56 for key in obj.keys():
57 obj[key] = squash_unicode(obj[key])
57 obj[key] = squash_unicode(obj[key])
58 if isinstance(key, unicode_type):
58 if isinstance(key, unicode_type):
59 obj[squash_unicode(key)] = obj.pop(key)
59 obj[squash_unicode(key)] = obj.pop(key)
60 elif isinstance(obj, list):
60 elif isinstance(obj, list):
61 for i,v in enumerate(obj):
61 for i,v in enumerate(obj):
62 obj[i] = squash_unicode(v)
62 obj[i] = squash_unicode(v)
63 elif isinstance(obj, unicode_type):
63 elif isinstance(obj, unicode_type):
64 obj = obj.encode('utf8')
64 obj = obj.encode('utf8')
65 return obj
65 return obj
66
66
67 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
68 # globals and defaults
68 # globals and defaults
69 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
70
70
71 # ISO8601-ify datetime objects
71 # ISO8601-ify datetime objects
72 # allow unicode
72 # allow unicode
73 # disallow nan, because it's not actually valid JSON
73 # disallow nan, because it's not actually valid JSON
74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 ensure_ascii=False, allow_nan=False,
75 ensure_ascii=False, allow_nan=False,
76 )
76 )
77 json_unpacker = lambda s: jsonapi.loads(s)
77 json_unpacker = lambda s: jsonapi.loads(s)
78
78
79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
80 pickle_unpacker = pickle.loads
80 pickle_unpacker = pickle.loads
81
81
82 default_packer = json_packer
82 default_packer = json_packer
83 default_unpacker = json_unpacker
83 default_unpacker = json_unpacker
84
84
85 DELIM = b"<IDS|MSG>"
85 DELIM = b"<IDS|MSG>"
86 # singleton dummy tracker, which will always report as done
86 # singleton dummy tracker, which will always report as done
87 DONE = zmq.MessageTracker()
87 DONE = zmq.MessageTracker()
88
88
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90 # Mixin tools for apps that use Sessions
90 # Mixin tools for apps that use Sessions
91 #-----------------------------------------------------------------------------
91 #-----------------------------------------------------------------------------
92
92
93 session_aliases = dict(
93 session_aliases = dict(
94 ident = 'Session.session',
94 ident = 'Session.session',
95 user = 'Session.username',
95 user = 'Session.username',
96 keyfile = 'Session.keyfile',
96 keyfile = 'Session.keyfile',
97 )
97 )
98
98
99 session_flags = {
99 session_flags = {
100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 'keyfile' : '' }},
101 'keyfile' : '' }},
102 """Use HMAC digests for authentication of messages.
102 """Use HMAC digests for authentication of messages.
103 Setting this flag will generate a new UUID to use as the HMAC key.
103 Setting this flag will generate a new UUID to use as the HMAC key.
104 """),
104 """),
105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 """Don't authenticate messages."""),
106 """Don't authenticate messages."""),
107 }
107 }
108
108
109 def default_secure(cfg):
109 def default_secure(cfg):
110 """Set the default behavior for a config environment to be secure.
110 """Set the default behavior for a config environment to be secure.
111
111
112 If Session.key/keyfile have not been set, set Session.key to
112 If Session.key/keyfile have not been set, set Session.key to
113 a new random UUID.
113 a new random UUID.
114 """
114 """
115
115
116 if 'Session' in cfg:
116 if 'Session' in cfg:
117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 return
118 return
119 # key/keyfile not specified, generate new UUID:
119 # key/keyfile not specified, generate new UUID:
120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121
121
122
122
123 #-----------------------------------------------------------------------------
123 #-----------------------------------------------------------------------------
124 # Classes
124 # Classes
125 #-----------------------------------------------------------------------------
125 #-----------------------------------------------------------------------------
126
126
127 class SessionFactory(LoggingConfigurable):
127 class SessionFactory(LoggingConfigurable):
128 """The Base class for configurables that have a Session, Context, logger,
128 """The Base class for configurables that have a Session, Context, logger,
129 and IOLoop.
129 and IOLoop.
130 """
130 """
131
131
132 logname = Unicode('')
132 logname = Unicode('')
133 def _logname_changed(self, name, old, new):
133 def _logname_changed(self, name, old, new):
134 self.log = logging.getLogger(new)
134 self.log = logging.getLogger(new)
135
135
136 # not configurable:
136 # not configurable:
137 context = Instance('zmq.Context')
137 context = Instance('zmq.Context')
138 def _context_default(self):
138 def _context_default(self):
139 return zmq.Context.instance()
139 return zmq.Context.instance()
140
140
141 session = Instance('IPython.kernel.zmq.session.Session')
141 session = Instance('IPython.kernel.zmq.session.Session')
142
142
143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 def _loop_default(self):
144 def _loop_default(self):
145 return IOLoop.instance()
145 return IOLoop.instance()
146
146
147 def __init__(self, **kwargs):
147 def __init__(self, **kwargs):
148 super(SessionFactory, self).__init__(**kwargs)
148 super(SessionFactory, self).__init__(**kwargs)
149
149
150 if self.session is None:
150 if self.session is None:
151 # construct the session
151 # construct the session
152 self.session = Session(**kwargs)
152 self.session = Session(**kwargs)
153
153
154
154
155 class Message(object):
155 class Message(object):
156 """A simple message object that maps dict keys to attributes.
156 """A simple message object that maps dict keys to attributes.
157
157
158 A Message can be created from a dict and a dict from a Message instance
158 A Message can be created from a dict and a dict from a Message instance
159 simply by calling dict(msg_obj)."""
159 simply by calling dict(msg_obj)."""
160
160
161 def __init__(self, msg_dict):
161 def __init__(self, msg_dict):
162 dct = self.__dict__
162 dct = self.__dict__
163 for k, v in iteritems(dict(msg_dict)):
163 for k, v in iteritems(dict(msg_dict)):
164 if isinstance(v, dict):
164 if isinstance(v, dict):
165 v = Message(v)
165 v = Message(v)
166 dct[k] = v
166 dct[k] = v
167
167
168 # Having this iterator lets dict(msg_obj) work out of the box.
168 # Having this iterator lets dict(msg_obj) work out of the box.
169 def __iter__(self):
169 def __iter__(self):
170 return iter(iteritems(self.__dict__))
170 return iter(iteritems(self.__dict__))
171
171
172 def __repr__(self):
172 def __repr__(self):
173 return repr(self.__dict__)
173 return repr(self.__dict__)
174
174
175 def __str__(self):
175 def __str__(self):
176 return pprint.pformat(self.__dict__)
176 return pprint.pformat(self.__dict__)
177
177
178 def __contains__(self, k):
178 def __contains__(self, k):
179 return k in self.__dict__
179 return k in self.__dict__
180
180
181 def __getitem__(self, k):
181 def __getitem__(self, k):
182 return self.__dict__[k]
182 return self.__dict__[k]
183
183
184
184
185 def msg_header(msg_id, msg_type, username, session):
185 def msg_header(msg_id, msg_type, username, session):
186 date = datetime.now()
186 date = datetime.now()
187 version = kernel_protocol_version
187 version = kernel_protocol_version
188 return locals()
188 return locals()
189
189
190 def extract_header(msg_or_header):
190 def extract_header(msg_or_header):
191 """Given a message or header, return the header."""
191 """Given a message or header, return the header."""
192 if not msg_or_header:
192 if not msg_or_header:
193 return {}
193 return {}
194 try:
194 try:
195 # See if msg_or_header is the entire message.
195 # See if msg_or_header is the entire message.
196 h = msg_or_header['header']
196 h = msg_or_header['header']
197 except KeyError:
197 except KeyError:
198 try:
198 try:
199 # See if msg_or_header is just the header
199 # See if msg_or_header is just the header
200 h = msg_or_header['msg_id']
200 h = msg_or_header['msg_id']
201 except KeyError:
201 except KeyError:
202 raise
202 raise
203 else:
203 else:
204 h = msg_or_header
204 h = msg_or_header
205 if not isinstance(h, dict):
205 if not isinstance(h, dict):
206 h = dict(h)
206 h = dict(h)
207 return h
207 return h
208
208
209 class Session(Configurable):
209 class Session(Configurable):
210 """Object for handling serialization and sending of messages.
210 """Object for handling serialization and sending of messages.
211
211
212 The Session object handles building messages and sending them
212 The Session object handles building messages and sending them
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 other over the network via Session objects, and only need to work with the
214 other over the network via Session objects, and only need to work with the
215 dict-based IPython message spec. The Session will handle
215 dict-based IPython message spec. The Session will handle
216 serialization/deserialization, security, and metadata.
216 serialization/deserialization, security, and metadata.
217
217
218 Sessions support configurable serialiization via packer/unpacker traits,
218 Sessions support configurable serialiization via packer/unpacker traits,
219 and signing with HMAC digests via the key/keyfile traits.
219 and signing with HMAC digests via the key/keyfile traits.
220
220
221 Parameters
221 Parameters
222 ----------
222 ----------
223
223
224 debug : bool
224 debug : bool
225 whether to trigger extra debugging statements
225 whether to trigger extra debugging statements
226 packer/unpacker : str : 'json', 'pickle' or import_string
226 packer/unpacker : str : 'json', 'pickle' or import_string
227 importstrings for methods to serialize message parts. If just
227 importstrings for methods to serialize message parts. If just
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 Otherwise, the entire importstring must be used.
229 Otherwise, the entire importstring must be used.
230
230
231 The functions must accept at least valid JSON input, and output *bytes*.
231 The functions must accept at least valid JSON input, and output *bytes*.
232
232
233 For example, to use msgpack:
233 For example, to use msgpack:
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 pack/unpack : callables
235 pack/unpack : callables
236 You can also set the pack/unpack callables for serialization directly.
236 You can also set the pack/unpack callables for serialization directly.
237 session : bytes
237 session : bytes
238 the ID of this Session object. The default is to generate a new UUID.
238 the ID of this Session object. The default is to generate a new UUID.
239 username : unicode
239 username : unicode
240 username added to message headers. The default is to ask the OS.
240 username added to message headers. The default is to ask the OS.
241 key : bytes
241 key : bytes
242 The key used to initialize an HMAC signature. If unset, messages
242 The key used to initialize an HMAC signature. If unset, messages
243 will not be signed or checked.
243 will not be signed or checked.
244 keyfile : filepath
244 keyfile : filepath
245 The file containing a key. If this is set, `key` will be initialized
245 The file containing a key. If this is set, `key` will be initialized
246 to the contents of the file.
246 to the contents of the file.
247
247
248 """
248 """
249
249
250 debug=Bool(False, config=True, help="""Debug output in the Session""")
250 debug=Bool(False, config=True, help="""Debug output in the Session""")
251
251
252 packer = DottedObjectName('json',config=True,
252 packer = DottedObjectName('json',config=True,
253 help="""The name of the packer for serializing messages.
253 help="""The name of the packer for serializing messages.
254 Should be one of 'json', 'pickle', or an import name
254 Should be one of 'json', 'pickle', or an import name
255 for a custom callable serializer.""")
255 for a custom callable serializer.""")
256 def _packer_changed(self, name, old, new):
256 def _packer_changed(self, name, old, new):
257 if new.lower() == 'json':
257 if new.lower() == 'json':
258 self.pack = json_packer
258 self.pack = json_packer
259 self.unpack = json_unpacker
259 self.unpack = json_unpacker
260 self.unpacker = new
260 self.unpacker = new
261 elif new.lower() == 'pickle':
261 elif new.lower() == 'pickle':
262 self.pack = pickle_packer
262 self.pack = pickle_packer
263 self.unpack = pickle_unpacker
263 self.unpack = pickle_unpacker
264 self.unpacker = new
264 self.unpacker = new
265 else:
265 else:
266 self.pack = import_item(str(new))
266 self.pack = import_item(str(new))
267
267
268 unpacker = DottedObjectName('json', config=True,
268 unpacker = DottedObjectName('json', config=True,
269 help="""The name of the unpacker for unserializing messages.
269 help="""The name of the unpacker for unserializing messages.
270 Only used with custom functions for `packer`.""")
270 Only used with custom functions for `packer`.""")
271 def _unpacker_changed(self, name, old, new):
271 def _unpacker_changed(self, name, old, new):
272 if new.lower() == 'json':
272 if new.lower() == 'json':
273 self.pack = json_packer
273 self.pack = json_packer
274 self.unpack = json_unpacker
274 self.unpack = json_unpacker
275 self.packer = new
275 self.packer = new
276 elif new.lower() == 'pickle':
276 elif new.lower() == 'pickle':
277 self.pack = pickle_packer
277 self.pack = pickle_packer
278 self.unpack = pickle_unpacker
278 self.unpack = pickle_unpacker
279 self.packer = new
279 self.packer = new
280 else:
280 else:
281 self.unpack = import_item(str(new))
281 self.unpack = import_item(str(new))
282
282
283 session = CUnicode(u'', config=True,
283 session = CUnicode(u'', config=True,
284 help="""The UUID identifying this session.""")
284 help="""The UUID identifying this session.""")
285 def _session_default(self):
285 def _session_default(self):
286 u = unicode_type(uuid.uuid4())
286 u = unicode_type(uuid.uuid4())
287 self.bsession = u.encode('ascii')
287 self.bsession = u.encode('ascii')
288 return u
288 return u
289
289
290 def _session_changed(self, name, old, new):
290 def _session_changed(self, name, old, new):
291 self.bsession = self.session.encode('ascii')
291 self.bsession = self.session.encode('ascii')
292
292
293 # bsession is the session as bytes
293 # bsession is the session as bytes
294 bsession = CBytes(b'')
294 bsession = CBytes(b'')
295
295
296 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
296 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
297 help="""Username for the Session. Default is your system username.""",
297 help="""Username for the Session. Default is your system username.""",
298 config=True)
298 config=True)
299
299
300 metadata = Dict({}, config=True,
300 metadata = Dict({}, config=True,
301 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
301 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
302
302
303 # if 0, no adapting to do.
303 # if 0, no adapting to do.
304 adapt_version = Integer(0)
304 adapt_version = Integer(0)
305
305
306 # message signature related traits:
306 # message signature related traits:
307
307
308 key = CBytes(b'', config=True,
308 key = CBytes(b'', config=True,
309 help="""execution key, for extra authentication.""")
309 help="""execution key, for extra authentication.""")
310 def _key_changed(self, name, old, new):
310 def _key_changed(self):
311 if new:
311 self._new_auth()
312 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
313 else:
314 self.auth = None
315
312
316 signature_scheme = Unicode('hmac-sha256', config=True,
313 signature_scheme = Unicode('hmac-sha256', config=True,
317 help="""The digest scheme used to construct the message signatures.
314 help="""The digest scheme used to construct the message signatures.
318 Must have the form 'hmac-HASH'.""")
315 Must have the form 'hmac-HASH'.""")
319 def _signature_scheme_changed(self, name, old, new):
316 def _signature_scheme_changed(self, name, old, new):
320 if not new.startswith('hmac-'):
317 if not new.startswith('hmac-'):
321 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
318 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
322 hash_name = new.split('-', 1)[1]
319 hash_name = new.split('-', 1)[1]
323 try:
320 try:
324 self.digest_mod = getattr(hashlib, hash_name)
321 self.digest_mod = getattr(hashlib, hash_name)
325 except AttributeError:
322 except AttributeError:
326 raise TraitError("hashlib has no such attribute: %s" % hash_name)
323 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 self._new_auth()
327
325
328 digest_mod = Any()
326 digest_mod = Any()
329 def _digest_mod_default(self):
327 def _digest_mod_default(self):
330 return hashlib.sha256
328 return hashlib.sha256
331
329
332 auth = Instance(hmac.HMAC)
330 auth = Instance(hmac.HMAC)
333
331
332 def _new_auth(self):
333 if self.key:
334 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 else:
336 self.auth = None
337
334 digest_history = Set()
338 digest_history = Set()
335 digest_history_size = Integer(2**16, config=True,
339 digest_history_size = Integer(2**16, config=True,
336 help="""The maximum number of digests to remember.
340 help="""The maximum number of digests to remember.
337
341
338 The digest history will be culled when it exceeds this value.
342 The digest history will be culled when it exceeds this value.
339 """
343 """
340 )
344 )
341
345
342 keyfile = Unicode('', config=True,
346 keyfile = Unicode('', config=True,
343 help="""path to file containing execution key.""")
347 help="""path to file containing execution key.""")
344 def _keyfile_changed(self, name, old, new):
348 def _keyfile_changed(self, name, old, new):
345 with open(new, 'rb') as f:
349 with open(new, 'rb') as f:
346 self.key = f.read().strip()
350 self.key = f.read().strip()
347
351
348 # for protecting against sends from forks
352 # for protecting against sends from forks
349 pid = Integer()
353 pid = Integer()
350
354
351 # serialization traits:
355 # serialization traits:
352
356
353 pack = Any(default_packer) # the actual packer function
357 pack = Any(default_packer) # the actual packer function
354 def _pack_changed(self, name, old, new):
358 def _pack_changed(self, name, old, new):
355 if not callable(new):
359 if not callable(new):
356 raise TypeError("packer must be callable, not %s"%type(new))
360 raise TypeError("packer must be callable, not %s"%type(new))
357
361
358 unpack = Any(default_unpacker) # the actual packer function
362 unpack = Any(default_unpacker) # the actual packer function
359 def _unpack_changed(self, name, old, new):
363 def _unpack_changed(self, name, old, new):
360 # unpacker is not checked - it is assumed to be
364 # unpacker is not checked - it is assumed to be
361 if not callable(new):
365 if not callable(new):
362 raise TypeError("unpacker must be callable, not %s"%type(new))
366 raise TypeError("unpacker must be callable, not %s"%type(new))
363
367
364 # thresholds:
368 # thresholds:
365 copy_threshold = Integer(2**16, config=True,
369 copy_threshold = Integer(2**16, config=True,
366 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
367 buffer_threshold = Integer(MAX_BYTES, config=True,
371 buffer_threshold = Integer(MAX_BYTES, config=True,
368 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
369 item_threshold = Integer(MAX_ITEMS, config=True,
373 item_threshold = Integer(MAX_ITEMS, config=True,
370 help="""The maximum number of items for a container to be introspected for custom serialization.
374 help="""The maximum number of items for a container to be introspected for custom serialization.
371 Containers larger than this are pickled outright.
375 Containers larger than this are pickled outright.
372 """
376 """
373 )
377 )
374
378
375
379
376 def __init__(self, **kwargs):
380 def __init__(self, **kwargs):
377 """create a Session object
381 """create a Session object
378
382
379 Parameters
383 Parameters
380 ----------
384 ----------
381
385
382 debug : bool
386 debug : bool
383 whether to trigger extra debugging statements
387 whether to trigger extra debugging statements
384 packer/unpacker : str : 'json', 'pickle' or import_string
388 packer/unpacker : str : 'json', 'pickle' or import_string
385 importstrings for methods to serialize message parts. If just
389 importstrings for methods to serialize message parts. If just
386 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 'json' or 'pickle', predefined JSON and pickle packers will be used.
387 Otherwise, the entire importstring must be used.
391 Otherwise, the entire importstring must be used.
388
392
389 The functions must accept at least valid JSON input, and output
393 The functions must accept at least valid JSON input, and output
390 *bytes*.
394 *bytes*.
391
395
392 For example, to use msgpack:
396 For example, to use msgpack:
393 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
394 pack/unpack : callables
398 pack/unpack : callables
395 You can also set the pack/unpack callables for serialization
399 You can also set the pack/unpack callables for serialization
396 directly.
400 directly.
397 session : unicode (must be ascii)
401 session : unicode (must be ascii)
398 the ID of this Session object. The default is to generate a new
402 the ID of this Session object. The default is to generate a new
399 UUID.
403 UUID.
400 bsession : bytes
404 bsession : bytes
401 The session as bytes
405 The session as bytes
402 username : unicode
406 username : unicode
403 username added to message headers. The default is to ask the OS.
407 username added to message headers. The default is to ask the OS.
404 key : bytes
408 key : bytes
405 The key used to initialize an HMAC signature. If unset, messages
409 The key used to initialize an HMAC signature. If unset, messages
406 will not be signed or checked.
410 will not be signed or checked.
407 signature_scheme : str
411 signature_scheme : str
408 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 The message digest scheme. Currently must be of the form 'hmac-HASH',
409 where 'HASH' is a hashing function available in Python's hashlib.
413 where 'HASH' is a hashing function available in Python's hashlib.
410 The default is 'hmac-sha256'.
414 The default is 'hmac-sha256'.
411 This is ignored if 'key' is empty.
415 This is ignored if 'key' is empty.
412 keyfile : filepath
416 keyfile : filepath
413 The file containing a key. If this is set, `key` will be
417 The file containing a key. If this is set, `key` will be
414 initialized to the contents of the file.
418 initialized to the contents of the file.
415 """
419 """
416 super(Session, self).__init__(**kwargs)
420 super(Session, self).__init__(**kwargs)
417 self._check_packers()
421 self._check_packers()
418 self.none = self.pack({})
422 self.none = self.pack({})
419 # ensure self._session_default() if necessary, so bsession is defined:
423 # ensure self._session_default() if necessary, so bsession is defined:
420 self.session
424 self.session
421 self.pid = os.getpid()
425 self.pid = os.getpid()
422
426
423 @property
427 @property
424 def msg_id(self):
428 def msg_id(self):
425 """always return new uuid"""
429 """always return new uuid"""
426 return str(uuid.uuid4())
430 return str(uuid.uuid4())
427
431
428 def _check_packers(self):
432 def _check_packers(self):
429 """check packers for datetime support."""
433 """check packers for datetime support."""
430 pack = self.pack
434 pack = self.pack
431 unpack = self.unpack
435 unpack = self.unpack
432
436
433 # check simple serialization
437 # check simple serialization
434 msg = dict(a=[1,'hi'])
438 msg = dict(a=[1,'hi'])
435 try:
439 try:
436 packed = pack(msg)
440 packed = pack(msg)
437 except Exception as e:
441 except Exception as e:
438 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
439 if self.packer == 'json':
443 if self.packer == 'json':
440 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
441 else:
445 else:
442 jsonmsg = ""
446 jsonmsg = ""
443 raise ValueError(
447 raise ValueError(
444 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
445 )
449 )
446
450
447 # ensure packed message is bytes
451 # ensure packed message is bytes
448 if not isinstance(packed, bytes):
452 if not isinstance(packed, bytes):
449 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
450
454
451 # check that unpack is pack's inverse
455 # check that unpack is pack's inverse
452 try:
456 try:
453 unpacked = unpack(packed)
457 unpacked = unpack(packed)
454 assert unpacked == msg
458 assert unpacked == msg
455 except Exception as e:
459 except Exception as e:
456 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
457 if self.packer == 'json':
461 if self.packer == 'json':
458 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
459 else:
463 else:
460 jsonmsg = ""
464 jsonmsg = ""
461 raise ValueError(
465 raise ValueError(
462 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
463 )
467 )
464
468
465 # check datetime support
469 # check datetime support
466 msg = dict(t=datetime.now())
470 msg = dict(t=datetime.now())
467 try:
471 try:
468 unpacked = unpack(pack(msg))
472 unpacked = unpack(pack(msg))
469 if isinstance(unpacked['t'], datetime):
473 if isinstance(unpacked['t'], datetime):
470 raise ValueError("Shouldn't deserialize to datetime")
474 raise ValueError("Shouldn't deserialize to datetime")
471 except Exception:
475 except Exception:
472 self.pack = lambda o: pack(squash_dates(o))
476 self.pack = lambda o: pack(squash_dates(o))
473 self.unpack = lambda s: unpack(s)
477 self.unpack = lambda s: unpack(s)
474
478
475 def msg_header(self, msg_type):
479 def msg_header(self, msg_type):
476 return msg_header(self.msg_id, msg_type, self.username, self.session)
480 return msg_header(self.msg_id, msg_type, self.username, self.session)
477
481
478 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
479 """Return the nested message dict.
483 """Return the nested message dict.
480
484
481 This format is different from what is sent over the wire. The
485 This format is different from what is sent over the wire. The
482 serialize/unserialize methods converts this nested message dict to the wire
486 serialize/unserialize methods converts this nested message dict to the wire
483 format, which is a list of message parts.
487 format, which is a list of message parts.
484 """
488 """
485 msg = {}
489 msg = {}
486 header = self.msg_header(msg_type) if header is None else header
490 header = self.msg_header(msg_type) if header is None else header
487 msg['header'] = header
491 msg['header'] = header
488 msg['msg_id'] = header['msg_id']
492 msg['msg_id'] = header['msg_id']
489 msg['msg_type'] = header['msg_type']
493 msg['msg_type'] = header['msg_type']
490 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 msg['parent_header'] = {} if parent is None else extract_header(parent)
491 msg['content'] = {} if content is None else content
495 msg['content'] = {} if content is None else content
492 msg['metadata'] = self.metadata.copy()
496 msg['metadata'] = self.metadata.copy()
493 if metadata is not None:
497 if metadata is not None:
494 msg['metadata'].update(metadata)
498 msg['metadata'].update(metadata)
495 return msg
499 return msg
496
500
497 def sign(self, msg_list):
501 def sign(self, msg_list):
498 """Sign a message with HMAC digest. If no auth, return b''.
502 """Sign a message with HMAC digest. If no auth, return b''.
499
503
500 Parameters
504 Parameters
501 ----------
505 ----------
502 msg_list : list
506 msg_list : list
503 The [p_header,p_parent,p_content] part of the message list.
507 The [p_header,p_parent,p_content] part of the message list.
504 """
508 """
505 if self.auth is None:
509 if self.auth is None:
506 return b''
510 return b''
507 h = self.auth.copy()
511 h = self.auth.copy()
508 for m in msg_list:
512 for m in msg_list:
509 h.update(m)
513 h.update(m)
510 return str_to_bytes(h.hexdigest())
514 return str_to_bytes(h.hexdigest())
511
515
512 def serialize(self, msg, ident=None):
516 def serialize(self, msg, ident=None):
513 """Serialize the message components to bytes.
517 """Serialize the message components to bytes.
514
518
515 This is roughly the inverse of unserialize. The serialize/unserialize
519 This is roughly the inverse of unserialize. The serialize/unserialize
516 methods work with full message lists, whereas pack/unpack work with
520 methods work with full message lists, whereas pack/unpack work with
517 the individual message parts in the message list.
521 the individual message parts in the message list.
518
522
519 Parameters
523 Parameters
520 ----------
524 ----------
521 msg : dict or Message
525 msg : dict or Message
522 The nexted message dict as returned by the self.msg method.
526 The nexted message dict as returned by the self.msg method.
523
527
524 Returns
528 Returns
525 -------
529 -------
526 msg_list : list
530 msg_list : list
527 The list of bytes objects to be sent with the format::
531 The list of bytes objects to be sent with the format::
528
532
529 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
530 p_metadata, p_content, buffer1, buffer2, ...]
534 p_metadata, p_content, buffer1, buffer2, ...]
531
535
532 In this list, the ``p_*`` entities are the packed or serialized
536 In this list, the ``p_*`` entities are the packed or serialized
533 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 versions, so if JSON is used, these are utf8 encoded JSON strings.
534 """
538 """
535 content = msg.get('content', {})
539 content = msg.get('content', {})
536 if content is None:
540 if content is None:
537 content = self.none
541 content = self.none
538 elif isinstance(content, dict):
542 elif isinstance(content, dict):
539 content = self.pack(content)
543 content = self.pack(content)
540 elif isinstance(content, bytes):
544 elif isinstance(content, bytes):
541 # content is already packed, as in a relayed message
545 # content is already packed, as in a relayed message
542 pass
546 pass
543 elif isinstance(content, unicode_type):
547 elif isinstance(content, unicode_type):
544 # should be bytes, but JSON often spits out unicode
548 # should be bytes, but JSON often spits out unicode
545 content = content.encode('utf8')
549 content = content.encode('utf8')
546 else:
550 else:
547 raise TypeError("Content incorrect type: %s"%type(content))
551 raise TypeError("Content incorrect type: %s"%type(content))
548
552
549 real_message = [self.pack(msg['header']),
553 real_message = [self.pack(msg['header']),
550 self.pack(msg['parent_header']),
554 self.pack(msg['parent_header']),
551 self.pack(msg['metadata']),
555 self.pack(msg['metadata']),
552 content,
556 content,
553 ]
557 ]
554
558
555 to_send = []
559 to_send = []
556
560
557 if isinstance(ident, list):
561 if isinstance(ident, list):
558 # accept list of idents
562 # accept list of idents
559 to_send.extend(ident)
563 to_send.extend(ident)
560 elif ident is not None:
564 elif ident is not None:
561 to_send.append(ident)
565 to_send.append(ident)
562 to_send.append(DELIM)
566 to_send.append(DELIM)
563
567
564 signature = self.sign(real_message)
568 signature = self.sign(real_message)
565 to_send.append(signature)
569 to_send.append(signature)
566
570
567 to_send.extend(real_message)
571 to_send.extend(real_message)
568
572
569 return to_send
573 return to_send
570
574
571 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
572 buffers=None, track=False, header=None, metadata=None):
576 buffers=None, track=False, header=None, metadata=None):
573 """Build and send a message via stream or socket.
577 """Build and send a message via stream or socket.
574
578
575 The message format used by this function internally is as follows:
579 The message format used by this function internally is as follows:
576
580
577 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
581 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
578 buffer1,buffer2,...]
582 buffer1,buffer2,...]
579
583
580 The serialize/unserialize methods convert the nested message dict into this
584 The serialize/unserialize methods convert the nested message dict into this
581 format.
585 format.
582
586
583 Parameters
587 Parameters
584 ----------
588 ----------
585
589
586 stream : zmq.Socket or ZMQStream
590 stream : zmq.Socket or ZMQStream
587 The socket-like object used to send the data.
591 The socket-like object used to send the data.
588 msg_or_type : str or Message/dict
592 msg_or_type : str or Message/dict
589 Normally, msg_or_type will be a msg_type unless a message is being
593 Normally, msg_or_type will be a msg_type unless a message is being
590 sent more than once. If a header is supplied, this can be set to
594 sent more than once. If a header is supplied, this can be set to
591 None and the msg_type will be pulled from the header.
595 None and the msg_type will be pulled from the header.
592
596
593 content : dict or None
597 content : dict or None
594 The content of the message (ignored if msg_or_type is a message).
598 The content of the message (ignored if msg_or_type is a message).
595 header : dict or None
599 header : dict or None
596 The header dict for the message (ignored if msg_to_type is a message).
600 The header dict for the message (ignored if msg_to_type is a message).
597 parent : Message or dict or None
601 parent : Message or dict or None
598 The parent or parent header describing the parent of this message
602 The parent or parent header describing the parent of this message
599 (ignored if msg_or_type is a message).
603 (ignored if msg_or_type is a message).
600 ident : bytes or list of bytes
604 ident : bytes or list of bytes
601 The zmq.IDENTITY routing path.
605 The zmq.IDENTITY routing path.
602 metadata : dict or None
606 metadata : dict or None
603 The metadata describing the message
607 The metadata describing the message
604 buffers : list or None
608 buffers : list or None
605 The already-serialized buffers to be appended to the message.
609 The already-serialized buffers to be appended to the message.
606 track : bool
610 track : bool
607 Whether to track. Only for use with Sockets, because ZMQStream
611 Whether to track. Only for use with Sockets, because ZMQStream
608 objects cannot track messages.
612 objects cannot track messages.
609
613
610
614
611 Returns
615 Returns
612 -------
616 -------
613 msg : dict
617 msg : dict
614 The constructed message.
618 The constructed message.
615 """
619 """
616 if not isinstance(stream, zmq.Socket):
620 if not isinstance(stream, zmq.Socket):
617 # ZMQStreams and dummy sockets do not support tracking.
621 # ZMQStreams and dummy sockets do not support tracking.
618 track = False
622 track = False
619
623
620 if isinstance(msg_or_type, (Message, dict)):
624 if isinstance(msg_or_type, (Message, dict)):
621 # We got a Message or message dict, not a msg_type so don't
625 # We got a Message or message dict, not a msg_type so don't
622 # build a new Message.
626 # build a new Message.
623 msg = msg_or_type
627 msg = msg_or_type
624 else:
628 else:
625 msg = self.msg(msg_or_type, content=content, parent=parent,
629 msg = self.msg(msg_or_type, content=content, parent=parent,
626 header=header, metadata=metadata)
630 header=header, metadata=metadata)
627 if not os.getpid() == self.pid:
631 if not os.getpid() == self.pid:
628 io.rprint("WARNING: attempted to send message from fork")
632 io.rprint("WARNING: attempted to send message from fork")
629 io.rprint(msg)
633 io.rprint(msg)
630 return
634 return
631 buffers = [] if buffers is None else buffers
635 buffers = [] if buffers is None else buffers
632 if self.adapt_version:
636 if self.adapt_version:
633 msg = adapt(msg, self.adapt_version)
637 msg = adapt(msg, self.adapt_version)
634 to_send = self.serialize(msg, ident)
638 to_send = self.serialize(msg, ident)
635 to_send.extend(buffers)
639 to_send.extend(buffers)
636 longest = max([ len(s) for s in to_send ])
640 longest = max([ len(s) for s in to_send ])
637 copy = (longest < self.copy_threshold)
641 copy = (longest < self.copy_threshold)
638
642
639 if buffers and track and not copy:
643 if buffers and track and not copy:
640 # only really track when we are doing zero-copy buffers
644 # only really track when we are doing zero-copy buffers
641 tracker = stream.send_multipart(to_send, copy=False, track=True)
645 tracker = stream.send_multipart(to_send, copy=False, track=True)
642 else:
646 else:
643 # use dummy tracker, which will be done immediately
647 # use dummy tracker, which will be done immediately
644 tracker = DONE
648 tracker = DONE
645 stream.send_multipart(to_send, copy=copy)
649 stream.send_multipart(to_send, copy=copy)
646
650
647 if self.debug:
651 if self.debug:
648 pprint.pprint(msg)
652 pprint.pprint(msg)
649 pprint.pprint(to_send)
653 pprint.pprint(to_send)
650 pprint.pprint(buffers)
654 pprint.pprint(buffers)
651
655
652 msg['tracker'] = tracker
656 msg['tracker'] = tracker
653
657
654 return msg
658 return msg
655
659
656 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
660 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
657 """Send a raw message via ident path.
661 """Send a raw message via ident path.
658
662
659 This method is used to send a already serialized message.
663 This method is used to send a already serialized message.
660
664
661 Parameters
665 Parameters
662 ----------
666 ----------
663 stream : ZMQStream or Socket
667 stream : ZMQStream or Socket
664 The ZMQ stream or socket to use for sending the message.
668 The ZMQ stream or socket to use for sending the message.
665 msg_list : list
669 msg_list : list
666 The serialized list of messages to send. This only includes the
670 The serialized list of messages to send. This only includes the
667 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
671 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
668 the message.
672 the message.
669 ident : ident or list
673 ident : ident or list
670 A single ident or a list of idents to use in sending.
674 A single ident or a list of idents to use in sending.
671 """
675 """
672 to_send = []
676 to_send = []
673 if isinstance(ident, bytes):
677 if isinstance(ident, bytes):
674 ident = [ident]
678 ident = [ident]
675 if ident is not None:
679 if ident is not None:
676 to_send.extend(ident)
680 to_send.extend(ident)
677
681
678 to_send.append(DELIM)
682 to_send.append(DELIM)
679 to_send.append(self.sign(msg_list))
683 to_send.append(self.sign(msg_list))
680 to_send.extend(msg_list)
684 to_send.extend(msg_list)
681 stream.send_multipart(to_send, flags, copy=copy)
685 stream.send_multipart(to_send, flags, copy=copy)
682
686
683 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
687 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
684 """Receive and unpack a message.
688 """Receive and unpack a message.
685
689
686 Parameters
690 Parameters
687 ----------
691 ----------
688 socket : ZMQStream or Socket
692 socket : ZMQStream or Socket
689 The socket or stream to use in receiving.
693 The socket or stream to use in receiving.
690
694
691 Returns
695 Returns
692 -------
696 -------
693 [idents], msg
697 [idents], msg
694 [idents] is a list of idents and msg is a nested message dict of
698 [idents] is a list of idents and msg is a nested message dict of
695 same format as self.msg returns.
699 same format as self.msg returns.
696 """
700 """
697 if isinstance(socket, ZMQStream):
701 if isinstance(socket, ZMQStream):
698 socket = socket.socket
702 socket = socket.socket
699 try:
703 try:
700 msg_list = socket.recv_multipart(mode, copy=copy)
704 msg_list = socket.recv_multipart(mode, copy=copy)
701 except zmq.ZMQError as e:
705 except zmq.ZMQError as e:
702 if e.errno == zmq.EAGAIN:
706 if e.errno == zmq.EAGAIN:
703 # We can convert EAGAIN to None as we know in this case
707 # We can convert EAGAIN to None as we know in this case
704 # recv_multipart won't return None.
708 # recv_multipart won't return None.
705 return None,None
709 return None,None
706 else:
710 else:
707 raise
711 raise
708 # split multipart message into identity list and message dict
712 # split multipart message into identity list and message dict
709 # invalid large messages can cause very expensive string comparisons
713 # invalid large messages can cause very expensive string comparisons
710 idents, msg_list = self.feed_identities(msg_list, copy)
714 idents, msg_list = self.feed_identities(msg_list, copy)
711 try:
715 try:
712 return idents, self.unserialize(msg_list, content=content, copy=copy)
716 return idents, self.unserialize(msg_list, content=content, copy=copy)
713 except Exception as e:
717 except Exception as e:
714 # TODO: handle it
718 # TODO: handle it
715 raise e
719 raise e
716
720
717 def feed_identities(self, msg_list, copy=True):
721 def feed_identities(self, msg_list, copy=True):
718 """Split the identities from the rest of the message.
722 """Split the identities from the rest of the message.
719
723
720 Feed until DELIM is reached, then return the prefix as idents and
724 Feed until DELIM is reached, then return the prefix as idents and
721 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
725 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
722 but that would be silly.
726 but that would be silly.
723
727
724 Parameters
728 Parameters
725 ----------
729 ----------
726 msg_list : a list of Message or bytes objects
730 msg_list : a list of Message or bytes objects
727 The message to be split.
731 The message to be split.
728 copy : bool
732 copy : bool
729 flag determining whether the arguments are bytes or Messages
733 flag determining whether the arguments are bytes or Messages
730
734
731 Returns
735 Returns
732 -------
736 -------
733 (idents, msg_list) : two lists
737 (idents, msg_list) : two lists
734 idents will always be a list of bytes, each of which is a ZMQ
738 idents will always be a list of bytes, each of which is a ZMQ
735 identity. msg_list will be a list of bytes or zmq.Messages of the
739 identity. msg_list will be a list of bytes or zmq.Messages of the
736 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
740 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
737 should be unpackable/unserializable via self.unserialize at this
741 should be unpackable/unserializable via self.unserialize at this
738 point.
742 point.
739 """
743 """
740 if copy:
744 if copy:
741 idx = msg_list.index(DELIM)
745 idx = msg_list.index(DELIM)
742 return msg_list[:idx], msg_list[idx+1:]
746 return msg_list[:idx], msg_list[idx+1:]
743 else:
747 else:
744 failed = True
748 failed = True
745 for idx,m in enumerate(msg_list):
749 for idx,m in enumerate(msg_list):
746 if m.bytes == DELIM:
750 if m.bytes == DELIM:
747 failed = False
751 failed = False
748 break
752 break
749 if failed:
753 if failed:
750 raise ValueError("DELIM not in msg_list")
754 raise ValueError("DELIM not in msg_list")
751 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
755 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
752 return [m.bytes for m in idents], msg_list
756 return [m.bytes for m in idents], msg_list
753
757
754 def _add_digest(self, signature):
758 def _add_digest(self, signature):
755 """add a digest to history to protect against replay attacks"""
759 """add a digest to history to protect against replay attacks"""
756 if self.digest_history_size == 0:
760 if self.digest_history_size == 0:
757 # no history, never add digests
761 # no history, never add digests
758 return
762 return
759
763
760 self.digest_history.add(signature)
764 self.digest_history.add(signature)
761 if len(self.digest_history) > self.digest_history_size:
765 if len(self.digest_history) > self.digest_history_size:
762 # threshold reached, cull 10%
766 # threshold reached, cull 10%
763 self._cull_digest_history()
767 self._cull_digest_history()
764
768
765 def _cull_digest_history(self):
769 def _cull_digest_history(self):
766 """cull the digest history
770 """cull the digest history
767
771
768 Removes a randomly selected 10% of the digest history
772 Removes a randomly selected 10% of the digest history
769 """
773 """
770 current = len(self.digest_history)
774 current = len(self.digest_history)
771 n_to_cull = max(int(current // 10), current - self.digest_history_size)
775 n_to_cull = max(int(current // 10), current - self.digest_history_size)
772 if n_to_cull >= current:
776 if n_to_cull >= current:
773 self.digest_history = set()
777 self.digest_history = set()
774 return
778 return
775 to_cull = random.sample(self.digest_history, n_to_cull)
779 to_cull = random.sample(self.digest_history, n_to_cull)
776 self.digest_history.difference_update(to_cull)
780 self.digest_history.difference_update(to_cull)
777
781
778 def unserialize(self, msg_list, content=True, copy=True):
782 def unserialize(self, msg_list, content=True, copy=True):
779 """Unserialize a msg_list to a nested message dict.
783 """Unserialize a msg_list to a nested message dict.
780
784
781 This is roughly the inverse of serialize. The serialize/unserialize
785 This is roughly the inverse of serialize. The serialize/unserialize
782 methods work with full message lists, whereas pack/unpack work with
786 methods work with full message lists, whereas pack/unpack work with
783 the individual message parts in the message list.
787 the individual message parts in the message list.
784
788
785 Parameters
789 Parameters
786 ----------
790 ----------
787 msg_list : list of bytes or Message objects
791 msg_list : list of bytes or Message objects
788 The list of message parts of the form [HMAC,p_header,p_parent,
792 The list of message parts of the form [HMAC,p_header,p_parent,
789 p_metadata,p_content,buffer1,buffer2,...].
793 p_metadata,p_content,buffer1,buffer2,...].
790 content : bool (True)
794 content : bool (True)
791 Whether to unpack the content dict (True), or leave it packed
795 Whether to unpack the content dict (True), or leave it packed
792 (False).
796 (False).
793 copy : bool (True)
797 copy : bool (True)
794 Whether to return the bytes (True), or the non-copying Message
798 Whether to return the bytes (True), or the non-copying Message
795 object in each place (False).
799 object in each place (False).
796
800
797 Returns
801 Returns
798 -------
802 -------
799 msg : dict
803 msg : dict
800 The nested message dict with top-level keys [header, parent_header,
804 The nested message dict with top-level keys [header, parent_header,
801 content, buffers].
805 content, buffers].
802 """
806 """
803 minlen = 5
807 minlen = 5
804 message = {}
808 message = {}
805 if not copy:
809 if not copy:
806 for i in range(minlen):
810 for i in range(minlen):
807 msg_list[i] = msg_list[i].bytes
811 msg_list[i] = msg_list[i].bytes
808 if self.auth is not None:
812 if self.auth is not None:
809 signature = msg_list[0]
813 signature = msg_list[0]
810 if not signature:
814 if not signature:
811 raise ValueError("Unsigned Message")
815 raise ValueError("Unsigned Message")
812 if signature in self.digest_history:
816 if signature in self.digest_history:
813 raise ValueError("Duplicate Signature: %r" % signature)
817 raise ValueError("Duplicate Signature: %r" % signature)
814 self._add_digest(signature)
818 self._add_digest(signature)
815 check = self.sign(msg_list[1:5])
819 check = self.sign(msg_list[1:5])
816 if not signature == check:
820 if not signature == check:
817 raise ValueError("Invalid Signature: %r" % signature)
821 raise ValueError("Invalid Signature: %r" % signature)
818 if not len(msg_list) >= minlen:
822 if not len(msg_list) >= minlen:
819 raise TypeError("malformed message, must have at least %i elements"%minlen)
823 raise TypeError("malformed message, must have at least %i elements"%minlen)
820 header = self.unpack(msg_list[1])
824 header = self.unpack(msg_list[1])
821 message['header'] = extract_dates(header)
825 message['header'] = extract_dates(header)
822 message['msg_id'] = header['msg_id']
826 message['msg_id'] = header['msg_id']
823 message['msg_type'] = header['msg_type']
827 message['msg_type'] = header['msg_type']
824 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
828 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
825 message['metadata'] = self.unpack(msg_list[3])
829 message['metadata'] = self.unpack(msg_list[3])
826 if content:
830 if content:
827 message['content'] = self.unpack(msg_list[4])
831 message['content'] = self.unpack(msg_list[4])
828 else:
832 else:
829 message['content'] = msg_list[4]
833 message['content'] = msg_list[4]
830
834
831 message['buffers'] = msg_list[5:]
835 message['buffers'] = msg_list[5:]
832 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
836 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
833 # adapt to the current version
837 # adapt to the current version
834 return adapt(message)
838 return adapt(message)
835 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
839 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
836
840
837 def test_msg2obj():
841 def test_msg2obj():
838 am = dict(x=1)
842 am = dict(x=1)
839 ao = Message(am)
843 ao = Message(am)
840 assert ao.x == am['x']
844 assert ao.x == am['x']
841
845
842 am['y'] = dict(z=1)
846 am['y'] = dict(z=1)
843 ao = Message(am)
847 ao = Message(am)
844 assert ao.y.z == am['y']['z']
848 assert ao.y.z == am['y']['z']
845
849
846 k1, k2 = 'y', 'z'
850 k1, k2 = 'y', 'z'
847 assert ao[k1][k2] == am[k1][k2]
851 assert ao[k1][k2] == am[k1][k2]
848
852
849 am2 = dict(ao)
853 am2 = dict(ao)
850 assert am['x'] == am2['x']
854 assert am['x'] == am2['x']
851 assert am['y']['z'] == am2['y']['z']
855 assert am['y']['z'] == am2['y']['z']
852
856
General Comments 0
You need to be logged in to leave comments. Login now