##// END OF EJS Templates
support Python older than 2.7.7
Paul Ivanov -
Show More
@@ -1,857 +1,864 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
# Prefer the C-accelerated pickle where it exists (Python 2). When cPickle
# cannot be imported, fall back to the stdlib pickle module and record that
# cPickle is unavailable so callers can test ``cPickle is None``.
try:
    import cPickle
    pickle = cPickle
except ImportError:
    # Only a missing module should trigger the fallback; any other error
    # raised while importing is a real problem and must not be hidden.
    cPickle = None
    import pickle
29 29
try:
    from hmac import compare_digest
except ImportError:
    # Python < 2.7.7 has no hmac.compare_digest
    def compare_digest(a, b):
        """Return True if `a` and `b` are equal.

        Fallback used only when ``hmac.compare_digest`` is unavailable.
        For two bytes objects of equal length, every byte pair is examined
        before answering, so the running time does not depend on the
        position of the first mismatch (a plain ``==`` stops early, which
        leaks how much of a forged signature was correct). A length
        mismatch is reported immediately.

        Any other combination of types falls back to ``a == b``.
        """
        if not (isinstance(a, bytes) and isinstance(b, bytes)):
            return a == b
        if len(a) != len(b):
            return False
        diff = 0
        # bytearray yields ints on both Python 2 and Python 3
        for x, y in zip(bytearray(a), bytearray(b)):
            diff |= x ^ y
        return diff == 0
36
30 37 import zmq
31 38 from zmq.utils import jsonapi
32 39 from zmq.eventloop.ioloop import IOLoop
33 40 from zmq.eventloop.zmqstream import ZMQStream
34 41
35 42 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 43 from IPython.config.configurable import Configurable, LoggingConfigurable
37 44 from IPython.utils import io
38 45 from IPython.utils.importstring import import_item
39 46 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 47 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 48 iteritems)
42 49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 50 DottedObjectName, CUnicode, Dict, Integer,
44 51 TraitError,
45 52 )
46 53 from IPython.utils.pickleutil import PICKLE_PROTOCOL
47 54 from IPython.kernel.adapter import adapt
48 55 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
49 56
50 57 #-----------------------------------------------------------------------------
51 58 # utility functions
52 59 #-----------------------------------------------------------------------------
53 60
def squash_unicode(obj):
    """Coerce unicode back to bytestrings.

    Dicts and lists are modified in place (recursively) and returned;
    unicode strings are returned utf8-encoded; anything else is returned
    unchanged.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: the loop body pops and
        # re-inserts entries, and mutating a dict while walking its live
        # key view is unsafe on Python 3 (RuntimeError, or keys visited
        # twice / skipped).
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
67 74
68 75 #-----------------------------------------------------------------------------
69 76 # globals and defaults
70 77 #-----------------------------------------------------------------------------
71 78
def json_packer(obj):
    """Serialize `obj` with zmq's jsonapi.

    * datetime objects are handled by ``date_default`` (ISO8601-ified)
    * unicode is allowed (``ensure_ascii=False``)
    * nan is disallowed, because it's not actually valid JSON
    """
    return jsonapi.dumps(
        obj,
        default=date_default,
        ensure_ascii=False,
        allow_nan=False,
    )


def json_unpacker(s):
    """Deserialize a JSON message part with zmq's jsonapi."""
    return jsonapi.loads(s)


def pickle_packer(o):
    """Pickle `o` (after squashing dates) using PICKLE_PROTOCOL."""
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


pickle_unpacker = pickle.loads

# JSON is the default wire serialization
default_packer = json_packer
default_unpacker = json_unpacker

# frame separating the routing identities from the message proper
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
89 96
90 97 #-----------------------------------------------------------------------------
91 98 # Mixin tools for apps that use Sessions
92 99 #-----------------------------------------------------------------------------
93 100
# command-line alias -> configurable trait
session_aliases = {
    'ident': 'Session.session',
    'user': 'Session.username',
    'keyfile': 'Session.keyfile',
}

# command-line flag -> (config overrides, help text)
session_flags = {
    'secure': (
        {'Session': {
            # a fresh key is generated once, when this module is imported
            'key': str_to_bytes(str(uuid.uuid4())),
            'keyfile': '',
        }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    'no-secure': (
        {'Session': {'key': b'', 'keyfile': ''}},
        """Don't authenticate messages.""",
    ),
}
109 116
def default_secure(cfg):
    """Make a config environment secure by default.

    When neither Session.key nor Session.keyfile has been set in `cfg`,
    Session.key is assigned a new random UUID. An existing key or keyfile
    setting is left untouched.
    """
    already_keyed = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if already_keyed:
        return
    # key/keyfile not specified, generate new UUID:
    fresh_key = str(uuid.uuid4())
    cfg.Session.key = str_to_bytes(fresh_key)
122 129
123 130
124 131 #-----------------------------------------------------------------------------
125 132 # Classes
126 133 #-----------------------------------------------------------------------------
127 134
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that carry a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        # rebind self.log to the logger named by the new trait value
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    session = Instance('IPython.kernel.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)
        if self.session is not None:
            return
        # no session was supplied: construct one from the same keyword args
        self.session = Session(**kwargs)
154 161
155 162
class Message(object):
    """Thin wrapper exposing the keys of a message dict as attributes.

    Build one from a dict with ``Message(d)``; get a dict back with
    ``dict(msg_obj)``. Values that are themselves dicts are wrapped in
    Message as well, so nested access like ``msg.header.msg_id`` works.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in dict(msg_dict).items():
            attrs[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # Yielding (key, value) pairs is what makes dict(msg_obj) work.
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
184 191
185 192
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    The result holds the four arguments under their own names, plus
    ``date`` (the current local time) and ``version`` (the kernel protocol
    version).
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
190 197
def extract_header(msg_or_header):
    """Return the header dict of a message, or of a bare header.

    A falsy argument yields an empty dict. If the argument has a 'header'
    key, that value is the header; otherwise the argument itself is taken
    to be the header, provided it has a 'msg_id' key (KeyError if not).
    A header that is not a dict is converted with ``dict()``.
    """
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        header = msg_or_header['header']
    except KeyError:
        # See if msg_or_header is just the header: this lookup exists only
        # to let KeyError propagate when 'msg_id' is missing too.
        msg_or_header['msg_id']
        header = msg_or_header
    if isinstance(header, dict):
        return header
    return dict(header)
209 216
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName('json', config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        # the built-in names select a matched pack/unpack pair;
        # anything else is an import string for the packer only
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
    def _unpacker_changed(self, name, old, new):
        # mirror image of _packer_changed
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        u = unicode_type(uuid.uuid4())
        # keep the bytes form in sync with the generated default
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")
    def _key_changed(self):
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")
    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        """(Re)build the HMAC object from the current key and digest_mod.

        With an empty key, auth is None and messages are neither signed
        nor checked.
        """
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        # only callability is checked here
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        # pre-packed empty dict, used by serialize when content is None
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for datetime support.

        Verifies that pack() produces bytes for a simple message and that
        unpack() inverts it, raising ValueError otherwise.  If the pair
        fails to round-trip a datetime, or deserializes it back to a
        datetime, pack is wrapped so that dates are squashed before packing.
        """
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            template = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            # explicit check rather than `assert`, which is stripped under -O
            if unpacked != msg:
                raise ValueError("unpacked message %r does not match the original" % (unpacked,))
        except Exception as e:
            template = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            )

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            # packer can't be trusted with dates: squash them before packing
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        """Return a new header for `msg_type`, with a fresh msg_id."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/unserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        # start from a copy of the session-wide metadata, then apply overrides
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_metadata,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        # copy, so the keyed base HMAC object is never consumed
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/unserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message, with the send tracker stored under the
            'tracker' key.  None is returned, and nothing is sent, when
            called from a process other than the one that created this
            Session.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if os.getpid() != self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # copy only when every frame is below the copy threshold
        longest = max(len(s) for s in to_send)
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.  (None, None) is returned when
            the receive fails with EAGAIN.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None,None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        # errors from unserialize propagate to the caller unchanged
        return idents, self.unserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
            and should be unpackable/unserializable via self.unserialize at this
            point.

        Raises
        ------
        ValueError
            If DELIM does not occur in msg_list.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx,m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history, or as many
        entries as needed to get back to digest_history_size, whichever is
        larger.
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample requires a sequence: sampling directly from a set
        # raises TypeError on Python 3.11+
        to_cull = random.sample(list(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list holds bytes (True), or non-copying Message
            objects (False).  In the latter case the first five parts are
            replaced in place by their bytes.

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].

        Raises
        ------
        TypeError
            If msg_list has fewer than five parts.
        ValueError
            If signing is enabled and the message is unsigned, carries a
            signature that has already been seen, or carries an invalid
            signature.
        """
        minlen = 5
        # validate the shape first, so a short message is reported as
        # malformed rather than failing while indexing or signing
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
            # remember only verified signatures, so forged messages cannot
            # fill the replay history and push genuine digests out of it
            self._add_digest(signature)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        # adapt to the current version
        return adapt(message)
841 848
def test_msg2obj():
    """Check Message attribute access, item access, and dict() round-trip."""
    plain = {'x': 1}
    wrapped = Message(plain)
    assert wrapped.x == plain['x']

    # nested dicts become nested Messages
    plain['y'] = {'z': 1}
    wrapped = Message(plain)
    assert wrapped.y.z == plain['y']['z']

    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == plain[outer][inner]

    # dict(msg_obj) relies on Message.__iter__
    roundtrip = dict(wrapped)
    assert plain['x'] == roundtrip['x']
    assert plain['y']['z'] == roundtrip['y']['z']
857 864
General Comments 0
You need to be logged in to leave comments. Login now