##// END OF EJS Templates
missing allow_none=True
Sylvain Corlay -
Show More
@@ -1,882 +1,882 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 import warnings
22 22 from datetime import datetime
23 23
24 24 try:
25 25 import cPickle
26 26 pickle = cPickle
27 27 except:
28 28 cPickle = None
29 29 import pickle
30 30
31 31 try:
32 32 # We are using compare_digest to limit the surface of timing attacks
33 33 from hmac import compare_digest
34 34 except ImportError:
35 35 # Python < 2.7.7: When digests don't match no feedback is provided,
36 36 # limiting the surface of attack
37 37 def compare_digest(a,b): return a == b
38 38
39 39 import zmq
40 40 from zmq.utils import jsonapi
41 41 from zmq.eventloop.ioloop import IOLoop
42 42 from zmq.eventloop.zmqstream import ZMQStream
43 43
44 44 from IPython.core.release import kernel_protocol_version
45 45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 46 from IPython.utils import io
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 50 iteritems)
51 51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 52 DottedObjectName, CUnicode, Dict, Integer,
53 53 TraitError,
54 54 )
55 55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 56 from IPython.kernel.adapter import adapt
57 57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # utility functions
61 61 #-----------------------------------------------------------------------------
62 62
63 63 def squash_unicode(obj):
64 64 """coerce unicode back to bytestrings."""
65 65 if isinstance(obj,dict):
66 66 for key in obj.keys():
67 67 obj[key] = squash_unicode(obj[key])
68 68 if isinstance(key, unicode_type):
69 69 obj[squash_unicode(key)] = obj.pop(key)
70 70 elif isinstance(obj, list):
71 71 for i,v in enumerate(obj):
72 72 obj[i] = squash_unicode(v)
73 73 elif isinstance(obj, unicode_type):
74 74 obj = obj.encode('utf8')
75 75 return obj
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # globals and defaults
79 79 #-----------------------------------------------------------------------------
80 80
81 81 # ISO8601-ify datetime objects
82 82 # allow unicode
83 83 # disallow nan, because it's not actually valid JSON
84 84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
85 85 ensure_ascii=False, allow_nan=False,
86 86 )
87 87 json_unpacker = lambda s: jsonapi.loads(s)
88 88
89 89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
90 90 pickle_unpacker = pickle.loads
91 91
92 92 default_packer = json_packer
93 93 default_unpacker = json_unpacker
94 94
95 95 DELIM = b"<IDS|MSG>"
96 96 # singleton dummy tracker, which will always report as done
97 97 DONE = zmq.MessageTracker()
98 98
99 99 #-----------------------------------------------------------------------------
100 100 # Mixin tools for apps that use Sessions
101 101 #-----------------------------------------------------------------------------
102 102
103 103 session_aliases = dict(
104 104 ident = 'Session.session',
105 105 user = 'Session.username',
106 106 keyfile = 'Session.keyfile',
107 107 )
108 108
109 109 session_flags = {
110 110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
111 111 'keyfile' : '' }},
112 112 """Use HMAC digests for authentication of messages.
113 113 Setting this flag will generate a new UUID to use as the HMAC key.
114 114 """),
115 115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
116 116 """Don't authenticate messages."""),
117 117 }
118 118
119 119 def default_secure(cfg):
120 120 """Set the default behavior for a config environment to be secure.
121 121
122 122 If Session.key/keyfile have not been set, set Session.key to
123 123 a new random UUID.
124 124 """
125 125 warnings.warn("default_secure is deprecated", DeprecationWarning)
126 126 if 'Session' in cfg:
127 127 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
128 128 return
129 129 # key/keyfile not specified, generate new UUID:
130 130 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
131 131
132 132
133 133 #-----------------------------------------------------------------------------
134 134 # Classes
135 135 #-----------------------------------------------------------------------------
136 136
137 137 class SessionFactory(LoggingConfigurable):
138 138 """The Base class for configurables that have a Session, Context, logger,
139 139 and IOLoop.
140 140 """
141 141
142 142 logname = Unicode('')
143 143 def _logname_changed(self, name, old, new):
144 144 self.log = logging.getLogger(new)
145 145
146 146 # not configurable:
147 147 context = Instance('zmq.Context')
148 148 def _context_default(self):
149 149 return zmq.Context.instance()
150 150
151 151 session = Instance('IPython.kernel.zmq.session.Session',
152 152 allow_none=True)
153 153
154 154 loop = Instance('zmq.eventloop.ioloop.IOLoop')
155 155 def _loop_default(self):
156 156 return IOLoop.instance()
157 157
158 158 def __init__(self, **kwargs):
159 159 super(SessionFactory, self).__init__(**kwargs)
160 160
161 161 if self.session is None:
162 162 # construct the session
163 163 self.session = Session(**kwargs)
164 164
165 165
166 166 class Message(object):
167 167 """A simple message object that maps dict keys to attributes.
168 168
169 169 A Message can be created from a dict and a dict from a Message instance
170 170 simply by calling dict(msg_obj)."""
171 171
172 172 def __init__(self, msg_dict):
173 173 dct = self.__dict__
174 174 for k, v in iteritems(dict(msg_dict)):
175 175 if isinstance(v, dict):
176 176 v = Message(v)
177 177 dct[k] = v
178 178
179 179 # Having this iterator lets dict(msg_obj) work out of the box.
180 180 def __iter__(self):
181 181 return iter(iteritems(self.__dict__))
182 182
183 183 def __repr__(self):
184 184 return repr(self.__dict__)
185 185
186 186 def __str__(self):
187 187 return pprint.pformat(self.__dict__)
188 188
189 189 def __contains__(self, k):
190 190 return k in self.__dict__
191 191
192 192 def __getitem__(self, k):
193 193 return self.__dict__[k]
194 194
195 195
196 196 def msg_header(msg_id, msg_type, username, session):
197 197 date = datetime.now()
198 198 version = kernel_protocol_version
199 199 return locals()
200 200
201 201 def extract_header(msg_or_header):
202 202 """Given a message or header, return the header."""
203 203 if not msg_or_header:
204 204 return {}
205 205 try:
206 206 # See if msg_or_header is the entire message.
207 207 h = msg_or_header['header']
208 208 except KeyError:
209 209 try:
210 210 # See if msg_or_header is just the header
211 211 h = msg_or_header['msg_id']
212 212 except KeyError:
213 213 raise
214 214 else:
215 215 h = msg_or_header
216 216 if not isinstance(h, dict):
217 217 h = dict(h)
218 218 return h
219 219
220 220 class Session(Configurable):
221 221 """Object for handling serialization and sending of messages.
222 222
223 223 The Session object handles building messages and sending them
224 224 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
225 225 other over the network via Session objects, and only need to work with the
226 226 dict-based IPython message spec. The Session will handle
227 227 serialization/deserialization, security, and metadata.
228 228
229 229 Sessions support configurable serialization via packer/unpacker traits,
230 230 and signing with HMAC digests via the key/keyfile traits.
231 231
232 232 Parameters
233 233 ----------
234 234
235 235 debug : bool
236 236 whether to trigger extra debugging statements
237 237 packer/unpacker : str : 'json', 'pickle' or import_string
238 238 importstrings for methods to serialize message parts. If just
239 239 'json' or 'pickle', predefined JSON and pickle packers will be used.
240 240 Otherwise, the entire importstring must be used.
241 241
242 242 The functions must accept at least valid JSON input, and output *bytes*.
243 243
244 244 For example, to use msgpack:
245 245 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
246 246 pack/unpack : callables
247 247 You can also set the pack/unpack callables for serialization directly.
248 248 session : bytes
249 249 the ID of this Session object. The default is to generate a new UUID.
250 250 username : unicode
251 251 username added to message headers. The default is to ask the OS.
252 252 key : bytes
253 253 The key used to initialize an HMAC signature. If unset, messages
254 254 will not be signed or checked.
255 255 keyfile : filepath
256 256 The file containing a key. If this is set, `key` will be initialized
257 257 to the contents of the file.
258 258
259 259 """
260 260
261 261 debug=Bool(False, config=True, help="""Debug output in the Session""")
262 262
263 263 packer = DottedObjectName('json',config=True,
264 264 help="""The name of the packer for serializing messages.
265 265 Should be one of 'json', 'pickle', or an import name
266 266 for a custom callable serializer.""")
267 267 def _packer_changed(self, name, old, new):
268 268 if new.lower() == 'json':
269 269 self.pack = json_packer
270 270 self.unpack = json_unpacker
271 271 self.unpacker = new
272 272 elif new.lower() == 'pickle':
273 273 self.pack = pickle_packer
274 274 self.unpack = pickle_unpacker
275 275 self.unpacker = new
276 276 else:
277 277 self.pack = import_item(str(new))
278 278
279 279 unpacker = DottedObjectName('json', config=True,
280 280 help="""The name of the unpacker for unserializing messages.
281 281 Only used with custom functions for `packer`.""")
282 282 def _unpacker_changed(self, name, old, new):
283 283 if new.lower() == 'json':
284 284 self.pack = json_packer
285 285 self.unpack = json_unpacker
286 286 self.packer = new
287 287 elif new.lower() == 'pickle':
288 288 self.pack = pickle_packer
289 289 self.unpack = pickle_unpacker
290 290 self.packer = new
291 291 else:
292 292 self.unpack = import_item(str(new))
293 293
294 294 session = CUnicode(u'', config=True,
295 295 help="""The UUID identifying this session.""")
296 296 def _session_default(self):
297 297 u = unicode_type(uuid.uuid4())
298 298 self.bsession = u.encode('ascii')
299 299 return u
300 300
301 301 def _session_changed(self, name, old, new):
302 302 self.bsession = self.session.encode('ascii')
303 303
304 304 # bsession is the session as bytes
305 305 bsession = CBytes(b'')
306 306
307 307 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
308 308 help="""Username for the Session. Default is your system username.""",
309 309 config=True)
310 310
311 311 metadata = Dict({}, config=True,
312 312 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
313 313
314 314 # if 0, no adapting to do.
315 315 adapt_version = Integer(0)
316 316
317 317 # message signature related traits:
318 318
319 319 key = CBytes(config=True,
320 320 help="""execution key, for signing messages.""")
321 321 def _key_default(self):
322 322 return str_to_bytes(str(uuid.uuid4()))
323 323
324 324 def _key_changed(self):
325 325 self._new_auth()
326 326
327 327 signature_scheme = Unicode('hmac-sha256', config=True,
328 328 help="""The digest scheme used to construct the message signatures.
329 329 Must have the form 'hmac-HASH'.""")
330 330 def _signature_scheme_changed(self, name, old, new):
331 331 if not new.startswith('hmac-'):
332 332 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
333 333 hash_name = new.split('-', 1)[1]
334 334 try:
335 335 self.digest_mod = getattr(hashlib, hash_name)
336 336 except AttributeError:
337 337 raise TraitError("hashlib has no such attribute: %s" % hash_name)
338 338 self._new_auth()
339 339
340 340 digest_mod = Any()
341 341 def _digest_mod_default(self):
342 342 return hashlib.sha256
343 343
344 auth = Instance(hmac.HMAC)
344 auth = Instance(hmac.HMAC, allow_none=True)
345 345
346 346 def _new_auth(self):
347 347 if self.key:
348 348 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
349 349 else:
350 350 self.auth = None
351 351
352 352 digest_history = Set()
353 353 digest_history_size = Integer(2**16, config=True,
354 354 help="""The maximum number of digests to remember.
355 355
356 356 The digest history will be culled when it exceeds this value.
357 357 """
358 358 )
359 359
360 360 keyfile = Unicode('', config=True,
361 361 help="""path to file containing execution key.""")
362 362 def _keyfile_changed(self, name, old, new):
363 363 with open(new, 'rb') as f:
364 364 self.key = f.read().strip()
365 365
366 366 # for protecting against sends from forks
367 367 pid = Integer()
368 368
369 369 # serialization traits:
370 370
371 371 pack = Any(default_packer) # the actual packer function
372 372 def _pack_changed(self, name, old, new):
373 373 if not callable(new):
374 374 raise TypeError("packer must be callable, not %s"%type(new))
375 375
376 376 unpack = Any(default_unpacker) # the actual packer function
377 377 def _unpack_changed(self, name, old, new):
378 378 # unpacker is not checked - it is assumed to be
379 379 if not callable(new):
380 380 raise TypeError("unpacker must be callable, not %s"%type(new))
381 381
382 382 # thresholds:
383 383 copy_threshold = Integer(2**16, config=True,
384 384 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
385 385 buffer_threshold = Integer(MAX_BYTES, config=True,
386 386 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
387 387 item_threshold = Integer(MAX_ITEMS, config=True,
388 388 help="""The maximum number of items for a container to be introspected for custom serialization.
389 389 Containers larger than this are pickled outright.
390 390 """
391 391 )
392 392
393 393
394 394 def __init__(self, **kwargs):
395 395 """create a Session object
396 396
397 397 Parameters
398 398 ----------
399 399
400 400 debug : bool
401 401 whether to trigger extra debugging statements
402 402 packer/unpacker : str : 'json', 'pickle' or import_string
403 403 importstrings for methods to serialize message parts. If just
404 404 'json' or 'pickle', predefined JSON and pickle packers will be used.
405 405 Otherwise, the entire importstring must be used.
406 406
407 407 The functions must accept at least valid JSON input, and output
408 408 *bytes*.
409 409
410 410 For example, to use msgpack:
411 411 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
412 412 pack/unpack : callables
413 413 You can also set the pack/unpack callables for serialization
414 414 directly.
415 415 session : unicode (must be ascii)
416 416 the ID of this Session object. The default is to generate a new
417 417 UUID.
418 418 bsession : bytes
419 419 The session as bytes
420 420 username : unicode
421 421 username added to message headers. The default is to ask the OS.
422 422 key : bytes
423 423 The key used to initialize an HMAC signature. If unset, messages
424 424 will not be signed or checked.
425 425 signature_scheme : str
426 426 The message digest scheme. Currently must be of the form 'hmac-HASH',
427 427 where 'HASH' is a hashing function available in Python's hashlib.
428 428 The default is 'hmac-sha256'.
429 429 This is ignored if 'key' is empty.
430 430 keyfile : filepath
431 431 The file containing a key. If this is set, `key` will be
432 432 initialized to the contents of the file.
433 433 """
434 434 super(Session, self).__init__(**kwargs)
435 435 self._check_packers()
436 436 self.none = self.pack({})
437 437 # ensure self._session_default() if necessary, so bsession is defined:
438 438 self.session
439 439 self.pid = os.getpid()
440 440 self._new_auth()
441 441
442 442 @property
443 443 def msg_id(self):
444 444 """always return new uuid"""
445 445 return str(uuid.uuid4())
446 446
447 447 def _check_packers(self):
448 448 """check packers for datetime support."""
449 449 pack = self.pack
450 450 unpack = self.unpack
451 451
452 452 # check simple serialization
453 453 msg = dict(a=[1,'hi'])
454 454 try:
455 455 packed = pack(msg)
456 456 except Exception as e:
457 457 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
458 458 if self.packer == 'json':
459 459 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
460 460 else:
461 461 jsonmsg = ""
462 462 raise ValueError(
463 463 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
464 464 )
465 465
466 466 # ensure packed message is bytes
467 467 if not isinstance(packed, bytes):
468 468 raise ValueError("message packed to %r, but bytes are required"%type(packed))
469 469
470 470 # check that unpack is pack's inverse
471 471 try:
472 472 unpacked = unpack(packed)
473 473 assert unpacked == msg
474 474 except Exception as e:
475 475 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
476 476 if self.packer == 'json':
477 477 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
478 478 else:
479 479 jsonmsg = ""
480 480 raise ValueError(
481 481 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
482 482 )
483 483
484 484 # check datetime support
485 485 msg = dict(t=datetime.now())
486 486 try:
487 487 unpacked = unpack(pack(msg))
488 488 if isinstance(unpacked['t'], datetime):
489 489 raise ValueError("Shouldn't deserialize to datetime")
490 490 except Exception:
491 491 self.pack = lambda o: pack(squash_dates(o))
492 492 self.unpack = lambda s: unpack(s)
493 493
494 494 def msg_header(self, msg_type):
495 495 return msg_header(self.msg_id, msg_type, self.username, self.session)
496 496
497 497 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
498 498 """Return the nested message dict.
499 499
500 500 This format is different from what is sent over the wire. The
501 501 serialize/deserialize methods converts this nested message dict to the wire
502 502 format, which is a list of message parts.
503 503 """
504 504 msg = {}
505 505 header = self.msg_header(msg_type) if header is None else header
506 506 msg['header'] = header
507 507 msg['msg_id'] = header['msg_id']
508 508 msg['msg_type'] = header['msg_type']
509 509 msg['parent_header'] = {} if parent is None else extract_header(parent)
510 510 msg['content'] = {} if content is None else content
511 511 msg['metadata'] = self.metadata.copy()
512 512 if metadata is not None:
513 513 msg['metadata'].update(metadata)
514 514 return msg
515 515
516 516 def sign(self, msg_list):
517 517 """Sign a message with HMAC digest. If no auth, return b''.
518 518
519 519 Parameters
520 520 ----------
521 521 msg_list : list
522 522 The [p_header,p_parent,p_content] part of the message list.
523 523 """
524 524 if self.auth is None:
525 525 return b''
526 526 h = self.auth.copy()
527 527 for m in msg_list:
528 528 h.update(m)
529 529 return str_to_bytes(h.hexdigest())
530 530
531 531 def serialize(self, msg, ident=None):
532 532 """Serialize the message components to bytes.
533 533
534 534 This is roughly the inverse of deserialize. The serialize/deserialize
535 535 methods work with full message lists, whereas pack/unpack work with
536 536 the individual message parts in the message list.
537 537
538 538 Parameters
539 539 ----------
540 540 msg : dict or Message
541 541 The next message dict as returned by the self.msg method.
542 542
543 543 Returns
544 544 -------
545 545 msg_list : list
546 546 The list of bytes objects to be sent with the format::
547 547
548 548 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
549 549 p_metadata, p_content, buffer1, buffer2, ...]
550 550
551 551 In this list, the ``p_*`` entities are the packed or serialized
552 552 versions, so if JSON is used, these are utf8 encoded JSON strings.
553 553 """
554 554 content = msg.get('content', {})
555 555 if content is None:
556 556 content = self.none
557 557 elif isinstance(content, dict):
558 558 content = self.pack(content)
559 559 elif isinstance(content, bytes):
560 560 # content is already packed, as in a relayed message
561 561 pass
562 562 elif isinstance(content, unicode_type):
563 563 # should be bytes, but JSON often spits out unicode
564 564 content = content.encode('utf8')
565 565 else:
566 566 raise TypeError("Content incorrect type: %s"%type(content))
567 567
568 568 real_message = [self.pack(msg['header']),
569 569 self.pack(msg['parent_header']),
570 570 self.pack(msg['metadata']),
571 571 content,
572 572 ]
573 573
574 574 to_send = []
575 575
576 576 if isinstance(ident, list):
577 577 # accept list of idents
578 578 to_send.extend(ident)
579 579 elif ident is not None:
580 580 to_send.append(ident)
581 581 to_send.append(DELIM)
582 582
583 583 signature = self.sign(real_message)
584 584 to_send.append(signature)
585 585
586 586 to_send.extend(real_message)
587 587
588 588 return to_send
589 589
590 590 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
591 591 buffers=None, track=False, header=None, metadata=None):
592 592 """Build and send a message via stream or socket.
593 593
594 594 The message format used by this function internally is as follows:
595 595
596 596 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
597 597 buffer1,buffer2,...]
598 598
599 599 The serialize/deserialize methods convert the nested message dict into this
600 600 format.
601 601
602 602 Parameters
603 603 ----------
604 604
605 605 stream : zmq.Socket or ZMQStream
606 606 The socket-like object used to send the data.
607 607 msg_or_type : str or Message/dict
608 608 Normally, msg_or_type will be a msg_type unless a message is being
609 609 sent more than once. If a header is supplied, this can be set to
610 610 None and the msg_type will be pulled from the header.
611 611
612 612 content : dict or None
613 613 The content of the message (ignored if msg_or_type is a message).
614 614 header : dict or None
615 615 The header dict for the message (ignored if msg_to_type is a message).
616 616 parent : Message or dict or None
617 617 The parent or parent header describing the parent of this message
618 618 (ignored if msg_or_type is a message).
619 619 ident : bytes or list of bytes
620 620 The zmq.IDENTITY routing path.
621 621 metadata : dict or None
622 622 The metadata describing the message
623 623 buffers : list or None
624 624 The already-serialized buffers to be appended to the message.
625 625 track : bool
626 626 Whether to track. Only for use with Sockets, because ZMQStream
627 627 objects cannot track messages.
628 628
629 629
630 630 Returns
631 631 -------
632 632 msg : dict
633 633 The constructed message.
634 634 """
635 635 if not isinstance(stream, zmq.Socket):
636 636 # ZMQStreams and dummy sockets do not support tracking.
637 637 track = False
638 638
639 639 if isinstance(msg_or_type, (Message, dict)):
640 640 # We got a Message or message dict, not a msg_type so don't
641 641 # build a new Message.
642 642 msg = msg_or_type
643 643 buffers = buffers or msg.get('buffers', [])
644 644 else:
645 645 msg = self.msg(msg_or_type, content=content, parent=parent,
646 646 header=header, metadata=metadata)
647 647 if not os.getpid() == self.pid:
648 648 io.rprint("WARNING: attempted to send message from fork")
649 649 io.rprint(msg)
650 650 return
651 651 buffers = [] if buffers is None else buffers
652 652 if self.adapt_version:
653 653 msg = adapt(msg, self.adapt_version)
654 654 to_send = self.serialize(msg, ident)
655 655 to_send.extend(buffers)
656 656 longest = max([ len(s) for s in to_send ])
657 657 copy = (longest < self.copy_threshold)
658 658
659 659 if buffers and track and not copy:
660 660 # only really track when we are doing zero-copy buffers
661 661 tracker = stream.send_multipart(to_send, copy=False, track=True)
662 662 else:
663 663 # use dummy tracker, which will be done immediately
664 664 tracker = DONE
665 665 stream.send_multipart(to_send, copy=copy)
666 666
667 667 if self.debug:
668 668 pprint.pprint(msg)
669 669 pprint.pprint(to_send)
670 670 pprint.pprint(buffers)
671 671
672 672 msg['tracker'] = tracker
673 673
674 674 return msg
675 675
676 676 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
677 677 """Send a raw message via ident path.
678 678
679 679 This method is used to send a already serialized message.
680 680
681 681 Parameters
682 682 ----------
683 683 stream : ZMQStream or Socket
684 684 The ZMQ stream or socket to use for sending the message.
685 685 msg_list : list
686 686 The serialized list of messages to send. This only includes the
687 687 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
688 688 the message.
689 689 ident : ident or list
690 690 A single ident or a list of idents to use in sending.
691 691 """
692 692 to_send = []
693 693 if isinstance(ident, bytes):
694 694 ident = [ident]
695 695 if ident is not None:
696 696 to_send.extend(ident)
697 697
698 698 to_send.append(DELIM)
699 699 to_send.append(self.sign(msg_list))
700 700 to_send.extend(msg_list)
701 701 stream.send_multipart(to_send, flags, copy=copy)
702 702
703 703 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
704 704 """Receive and unpack a message.
705 705
706 706 Parameters
707 707 ----------
708 708 socket : ZMQStream or Socket
709 709 The socket or stream to use in receiving.
710 710
711 711 Returns
712 712 -------
713 713 [idents], msg
714 714 [idents] is a list of idents and msg is a nested message dict of
715 715 same format as self.msg returns.
716 716 """
717 717 if isinstance(socket, ZMQStream):
718 718 socket = socket.socket
719 719 try:
720 720 msg_list = socket.recv_multipart(mode, copy=copy)
721 721 except zmq.ZMQError as e:
722 722 if e.errno == zmq.EAGAIN:
723 723 # We can convert EAGAIN to None as we know in this case
724 724 # recv_multipart won't return None.
725 725 return None,None
726 726 else:
727 727 raise
728 728 # split multipart message into identity list and message dict
729 729 # invalid large messages can cause very expensive string comparisons
730 730 idents, msg_list = self.feed_identities(msg_list, copy)
731 731 try:
732 732 return idents, self.deserialize(msg_list, content=content, copy=copy)
733 733 except Exception as e:
734 734 # TODO: handle it
735 735 raise e
736 736
737 737 def feed_identities(self, msg_list, copy=True):
738 738 """Split the identities from the rest of the message.
739 739
740 740 Feed until DELIM is reached, then return the prefix as idents and
741 741 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
742 742 but that would be silly.
743 743
744 744 Parameters
745 745 ----------
746 746 msg_list : a list of Message or bytes objects
747 747 The message to be split.
748 748 copy : bool
749 749 flag determining whether the arguments are bytes or Messages
750 750
751 751 Returns
752 752 -------
753 753 (idents, msg_list) : two lists
754 754 idents will always be a list of bytes, each of which is a ZMQ
755 755 identity. msg_list will be a list of bytes or zmq.Messages of the
756 756 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
757 757 should be unpackable/unserializable via self.deserialize at this
758 758 point.
759 759 """
760 760 if copy:
761 761 idx = msg_list.index(DELIM)
762 762 return msg_list[:idx], msg_list[idx+1:]
763 763 else:
764 764 failed = True
765 765 for idx,m in enumerate(msg_list):
766 766 if m.bytes == DELIM:
767 767 failed = False
768 768 break
769 769 if failed:
770 770 raise ValueError("DELIM not in msg_list")
771 771 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
772 772 return [m.bytes for m in idents], msg_list
773 773
774 774 def _add_digest(self, signature):
775 775 """add a digest to history to protect against replay attacks"""
776 776 if self.digest_history_size == 0:
777 777 # no history, never add digests
778 778 return
779 779
780 780 self.digest_history.add(signature)
781 781 if len(self.digest_history) > self.digest_history_size:
782 782 # threshold reached, cull 10%
783 783 self._cull_digest_history()
784 784
785 785 def _cull_digest_history(self):
786 786 """cull the digest history
787 787
788 788 Removes a randomly selected 10% of the digest history
789 789 """
790 790 current = len(self.digest_history)
791 791 n_to_cull = max(int(current // 10), current - self.digest_history_size)
792 792 if n_to_cull >= current:
793 793 self.digest_history = set()
794 794 return
795 795 to_cull = random.sample(self.digest_history, n_to_cull)
796 796 self.digest_history.difference_update(to_cull)
797 797
798 798 def deserialize(self, msg_list, content=True, copy=True):
799 799 """Unserialize a msg_list to a nested message dict.
800 800
801 801 This is roughly the inverse of serialize. The serialize/deserialize
802 802 methods work with full message lists, whereas pack/unpack work with
803 803 the individual message parts in the message list.
804 804
805 805 Parameters
806 806 ----------
807 807 msg_list : list of bytes or Message objects
808 808 The list of message parts of the form [HMAC,p_header,p_parent,
809 809 p_metadata,p_content,buffer1,buffer2,...].
810 810 content : bool (True)
811 811 Whether to unpack the content dict (True), or leave it packed
812 812 (False).
813 813 copy : bool (True)
814 814 Whether msg_list contains bytes (True) or the non-copying Message
815 815 objects in each place (False).
816 816
817 817 Returns
818 818 -------
819 819 msg : dict
820 820 The nested message dict with top-level keys [header, parent_header,
821 821 content, buffers]. The buffers are returned as memoryviews.
822 822 """
823 823 minlen = 5
824 824 message = {}
825 825 if not copy:
826 826 # pyzmq didn't copy the first parts of the message, so we'll do it
827 827 for i in range(minlen):
828 828 msg_list[i] = msg_list[i].bytes
829 829 if self.auth is not None:
830 830 signature = msg_list[0]
831 831 if not signature:
832 832 raise ValueError("Unsigned Message")
833 833 if signature in self.digest_history:
834 834 raise ValueError("Duplicate Signature: %r" % signature)
835 835 self._add_digest(signature)
836 836 check = self.sign(msg_list[1:5])
837 837 if not compare_digest(signature, check):
838 838 raise ValueError("Invalid Signature: %r" % signature)
839 839 if not len(msg_list) >= minlen:
840 840 raise TypeError("malformed message, must have at least %i elements"%minlen)
841 841 header = self.unpack(msg_list[1])
842 842 message['header'] = extract_dates(header)
843 843 message['msg_id'] = header['msg_id']
844 844 message['msg_type'] = header['msg_type']
845 845 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
846 846 message['metadata'] = self.unpack(msg_list[3])
847 847 if content:
848 848 message['content'] = self.unpack(msg_list[4])
849 849 else:
850 850 message['content'] = msg_list[4]
851 851 buffers = [memoryview(b) for b in msg_list[5:]]
852 852 if buffers and buffers[0].shape is None:
853 853 # force copy to workaround pyzmq #646
854 854 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
855 855 message['buffers'] = buffers
856 856 # adapt to the current version
857 857 return adapt(message)
858 858
859 859 def unserialize(self, *args, **kwargs):
860 860 warnings.warn(
861 861 "Session.unserialize is deprecated. Use Session.deserialize.",
862 862 DeprecationWarning,
863 863 )
864 864 return self.deserialize(*args, **kwargs)
865 865
866 866
867 867 def test_msg2obj():
868 868 am = dict(x=1)
869 869 ao = Message(am)
870 870 assert ao.x == am['x']
871 871
872 872 am['y'] = dict(z=1)
873 873 ao = Message(am)
874 874 assert ao.y.z == am['y']['z']
875 875
876 876 k1, k2 = 'y', 'z'
877 877 assert ao[k1][k2] == am[k1][k2]
878 878
879 879 am2 = dict(ao)
880 880 assert am['x'] == am2['x']
881 881 assert am['y']['z'] == am2['y']['z']
882 882
General Comments 0
You need to be logged in to leave comments. Login now