##// END OF EJS Templates
Don't use io.rprint in jupyter_client.session...
Min RK -
Show More
@@ -1,889 +1,891 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 import warnings
22 22 from datetime import datetime
23 23
24 24 try:
25 25 import cPickle
26 26 pickle = cPickle
27 27 except:
28 28 cPickle = None
29 29 import pickle
30 30
31 31 try:
32 32 # py3
33 33 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
34 34 except AttributeError:
35 35 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
36 36
37 37 try:
38 38 # We are using compare_digest to limit the surface of timing attacks
39 39 from hmac import compare_digest
40 40 except ImportError:
41 41 # Python < 2.7.7: When digests don't match no feedback is provided,
42 42 # limiting the surface of attack
43 43 def compare_digest(a,b): return a == b
44 44
45 45 import zmq
46 46 from zmq.utils import jsonapi
47 47 from zmq.eventloop.ioloop import IOLoop
48 48 from zmq.eventloop.zmqstream import ZMQStream
49 49
50 50 from IPython.core.release import kernel_protocol_version
51 51 from IPython.config.configurable import Configurable, LoggingConfigurable
52 from IPython.utils import io
53 52 from IPython.utils.importstring import import_item
54 53 from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
55 54 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
56 55 iteritems)
57 56 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
58 57 DottedObjectName, CUnicode, Dict, Integer,
59 58 TraitError,
60 59 )
61 60 from jupyter_client.adapter import adapt
61 from traitlets.log import get_logger
62
62 63
63 64 #-----------------------------------------------------------------------------
64 65 # utility functions
65 66 #-----------------------------------------------------------------------------
66 67
67 68 def squash_unicode(obj):
68 69 """coerce unicode back to bytestrings."""
69 70 if isinstance(obj,dict):
70 71 for key in obj.keys():
71 72 obj[key] = squash_unicode(obj[key])
72 73 if isinstance(key, unicode_type):
73 74 obj[squash_unicode(key)] = obj.pop(key)
74 75 elif isinstance(obj, list):
75 76 for i,v in enumerate(obj):
76 77 obj[i] = squash_unicode(v)
77 78 elif isinstance(obj, unicode_type):
78 79 obj = obj.encode('utf8')
79 80 return obj
80 81
81 82 #-----------------------------------------------------------------------------
82 83 # globals and defaults
83 84 #-----------------------------------------------------------------------------
84 85
85 86 # default values for the thresholds:
86 87 MAX_ITEMS = 64
87 88 MAX_BYTES = 1024
88 89
89 90 # ISO8601-ify datetime objects
90 91 # allow unicode
91 92 # disallow nan, because it's not actually valid JSON
92 93 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
93 94 ensure_ascii=False, allow_nan=False,
94 95 )
95 96 json_unpacker = lambda s: jsonapi.loads(s)
96 97
97 98 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
98 99 pickle_unpacker = pickle.loads
99 100
100 101 default_packer = json_packer
101 102 default_unpacker = json_unpacker
102 103
103 104 DELIM = b"<IDS|MSG>"
104 105 # singleton dummy tracker, which will always report as done
105 106 DONE = zmq.MessageTracker()
106 107
107 108 #-----------------------------------------------------------------------------
108 109 # Mixin tools for apps that use Sessions
109 110 #-----------------------------------------------------------------------------
110 111
111 112 session_aliases = dict(
112 113 ident = 'Session.session',
113 114 user = 'Session.username',
114 115 keyfile = 'Session.keyfile',
115 116 )
116 117
117 118 session_flags = {
118 119 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
119 120 'keyfile' : '' }},
120 121 """Use HMAC digests for authentication of messages.
121 122 Setting this flag will generate a new UUID to use as the HMAC key.
122 123 """),
123 124 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
124 125 """Don't authenticate messages."""),
125 126 }
126 127
127 128 def default_secure(cfg):
128 129 """Set the default behavior for a config environment to be secure.
129 130
130 131 If Session.key/keyfile have not been set, set Session.key to
131 132 a new random UUID.
132 133 """
133 134 warnings.warn("default_secure is deprecated", DeprecationWarning)
134 135 if 'Session' in cfg:
135 136 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
136 137 return
137 138 # key/keyfile not specified, generate new UUID:
138 139 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
139 140
140 141
141 142 #-----------------------------------------------------------------------------
142 143 # Classes
143 144 #-----------------------------------------------------------------------------
144 145
145 146 class SessionFactory(LoggingConfigurable):
146 147 """The Base class for configurables that have a Session, Context, logger,
147 148 and IOLoop.
148 149 """
149 150
150 151 logname = Unicode('')
151 152 def _logname_changed(self, name, old, new):
152 153 self.log = logging.getLogger(new)
153 154
154 155 # not configurable:
155 156 context = Instance('zmq.Context')
156 157 def _context_default(self):
157 158 return zmq.Context.instance()
158 159
159 160 session = Instance('jupyter_client.session.Session',
160 161 allow_none=True)
161 162
162 163 loop = Instance('zmq.eventloop.ioloop.IOLoop')
163 164 def _loop_default(self):
164 165 return IOLoop.instance()
165 166
166 167 def __init__(self, **kwargs):
167 168 super(SessionFactory, self).__init__(**kwargs)
168 169
169 170 if self.session is None:
170 171 # construct the session
171 172 self.session = Session(**kwargs)
172 173
173 174
174 175 class Message(object):
175 176 """A simple message object that maps dict keys to attributes.
176 177
177 178 A Message can be created from a dict and a dict from a Message instance
178 179 simply by calling dict(msg_obj)."""
179 180
180 181 def __init__(self, msg_dict):
181 182 dct = self.__dict__
182 183 for k, v in iteritems(dict(msg_dict)):
183 184 if isinstance(v, dict):
184 185 v = Message(v)
185 186 dct[k] = v
186 187
187 188 # Having this iterator lets dict(msg_obj) work out of the box.
188 189 def __iter__(self):
189 190 return iter(iteritems(self.__dict__))
190 191
191 192 def __repr__(self):
192 193 return repr(self.__dict__)
193 194
194 195 def __str__(self):
195 196 return pprint.pformat(self.__dict__)
196 197
197 198 def __contains__(self, k):
198 199 return k in self.__dict__
199 200
200 201 def __getitem__(self, k):
201 202 return self.__dict__[k]
202 203
203 204
204 205 def msg_header(msg_id, msg_type, username, session):
205 206 date = datetime.now()
206 207 version = kernel_protocol_version
207 208 return locals()
208 209
209 210 def extract_header(msg_or_header):
210 211 """Given a message or header, return the header."""
211 212 if not msg_or_header:
212 213 return {}
213 214 try:
214 215 # See if msg_or_header is the entire message.
215 216 h = msg_or_header['header']
216 217 except KeyError:
217 218 try:
218 219 # See if msg_or_header is just the header
219 220 h = msg_or_header['msg_id']
220 221 except KeyError:
221 222 raise
222 223 else:
223 224 h = msg_or_header
224 225 if not isinstance(h, dict):
225 226 h = dict(h)
226 227 return h
227 228
228 229 class Session(Configurable):
229 230 """Object for handling serialization and sending of messages.
230 231
231 232 The Session object handles building messages and sending them
232 233 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
233 234 other over the network via Session objects, and only need to work with the
234 235 dict-based IPython message spec. The Session will handle
235 236 serialization/deserialization, security, and metadata.
236 237
237 238 Sessions support configurable serialization via packer/unpacker traits,
238 239 and signing with HMAC digests via the key/keyfile traits.
239 240
240 241 Parameters
241 242 ----------
242 243
243 244 debug : bool
244 245 whether to trigger extra debugging statements
245 246 packer/unpacker : str : 'json', 'pickle' or import_string
246 247 importstrings for methods to serialize message parts. If just
247 248 'json' or 'pickle', predefined JSON and pickle packers will be used.
248 249 Otherwise, the entire importstring must be used.
249 250
250 251 The functions must accept at least valid JSON input, and output *bytes*.
251 252
252 253 For example, to use msgpack:
253 254 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
254 255 pack/unpack : callables
255 256 You can also set the pack/unpack callables for serialization directly.
256 257 session : bytes
257 258 the ID of this Session object. The default is to generate a new UUID.
258 259 username : unicode
259 260 username added to message headers. The default is to ask the OS.
260 261 key : bytes
261 262 The key used to initialize an HMAC signature. If unset, messages
262 263 will not be signed or checked.
263 264 keyfile : filepath
264 265 The file containing a key. If this is set, `key` will be initialized
265 266 to the contents of the file.
266 267
267 268 """
268 269
269 270 debug=Bool(False, config=True, help="""Debug output in the Session""")
270 271
271 272 packer = DottedObjectName('json',config=True,
272 273 help="""The name of the packer for serializing messages.
273 274 Should be one of 'json', 'pickle', or an import name
274 275 for a custom callable serializer.""")
275 276 def _packer_changed(self, name, old, new):
276 277 if new.lower() == 'json':
277 278 self.pack = json_packer
278 279 self.unpack = json_unpacker
279 280 self.unpacker = new
280 281 elif new.lower() == 'pickle':
281 282 self.pack = pickle_packer
282 283 self.unpack = pickle_unpacker
283 284 self.unpacker = new
284 285 else:
285 286 self.pack = import_item(str(new))
286 287
287 288 unpacker = DottedObjectName('json', config=True,
288 289 help="""The name of the unpacker for unserializing messages.
289 290 Only used with custom functions for `packer`.""")
290 291 def _unpacker_changed(self, name, old, new):
291 292 if new.lower() == 'json':
292 293 self.pack = json_packer
293 294 self.unpack = json_unpacker
294 295 self.packer = new
295 296 elif new.lower() == 'pickle':
296 297 self.pack = pickle_packer
297 298 self.unpack = pickle_unpacker
298 299 self.packer = new
299 300 else:
300 301 self.unpack = import_item(str(new))
301 302
302 303 session = CUnicode(u'', config=True,
303 304 help="""The UUID identifying this session.""")
304 305 def _session_default(self):
305 306 u = unicode_type(uuid.uuid4())
306 307 self.bsession = u.encode('ascii')
307 308 return u
308 309
309 310 def _session_changed(self, name, old, new):
310 311 self.bsession = self.session.encode('ascii')
311 312
312 313 # bsession is the session as bytes
313 314 bsession = CBytes(b'')
314 315
315 316 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
316 317 help="""Username for the Session. Default is your system username.""",
317 318 config=True)
318 319
319 320 metadata = Dict({}, config=True,
320 321 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
321 322
322 323 # if 0, no adapting to do.
323 324 adapt_version = Integer(0)
324 325
325 326 # message signature related traits:
326 327
327 328 key = CBytes(config=True,
328 329 help="""execution key, for signing messages.""")
329 330 def _key_default(self):
330 331 return str_to_bytes(str(uuid.uuid4()))
331 332
332 333 def _key_changed(self):
333 334 self._new_auth()
334 335
335 336 signature_scheme = Unicode('hmac-sha256', config=True,
336 337 help="""The digest scheme used to construct the message signatures.
337 338 Must have the form 'hmac-HASH'.""")
338 339 def _signature_scheme_changed(self, name, old, new):
339 340 if not new.startswith('hmac-'):
340 341 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
341 342 hash_name = new.split('-', 1)[1]
342 343 try:
343 344 self.digest_mod = getattr(hashlib, hash_name)
344 345 except AttributeError:
345 346 raise TraitError("hashlib has no such attribute: %s" % hash_name)
346 347 self._new_auth()
347 348
348 349 digest_mod = Any()
349 350 def _digest_mod_default(self):
350 351 return hashlib.sha256
351 352
352 353 auth = Instance(hmac.HMAC, allow_none=True)
353 354
354 355 def _new_auth(self):
355 356 if self.key:
356 357 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
357 358 else:
358 359 self.auth = None
359 360
360 361 digest_history = Set()
361 362 digest_history_size = Integer(2**16, config=True,
362 363 help="""The maximum number of digests to remember.
363 364
364 365 The digest history will be culled when it exceeds this value.
365 366 """
366 367 )
367 368
368 369 keyfile = Unicode('', config=True,
369 370 help="""path to file containing execution key.""")
370 371 def _keyfile_changed(self, name, old, new):
371 372 with open(new, 'rb') as f:
372 373 self.key = f.read().strip()
373 374
374 375 # for protecting against sends from forks
375 376 pid = Integer()
376 377
377 378 # serialization traits:
378 379
379 380 pack = Any(default_packer) # the actual packer function
380 381 def _pack_changed(self, name, old, new):
381 382 if not callable(new):
382 383 raise TypeError("packer must be callable, not %s"%type(new))
383 384
384 385 unpack = Any(default_unpacker) # the actual packer function
385 386 def _unpack_changed(self, name, old, new):
386 387 # unpacker is not checked - it is assumed to be
387 388 if not callable(new):
388 389 raise TypeError("unpacker must be callable, not %s"%type(new))
389 390
390 391 # thresholds:
391 392 copy_threshold = Integer(2**16, config=True,
392 393 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
393 394 buffer_threshold = Integer(MAX_BYTES, config=True,
394 395 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
395 396 item_threshold = Integer(MAX_ITEMS, config=True,
396 397 help="""The maximum number of items for a container to be introspected for custom serialization.
397 398 Containers larger than this are pickled outright.
398 399 """
399 400 )
400 401
401 402
402 403 def __init__(self, **kwargs):
403 404 """create a Session object
404 405
405 406 Parameters
406 407 ----------
407 408
408 409 debug : bool
409 410 whether to trigger extra debugging statements
410 411 packer/unpacker : str : 'json', 'pickle' or import_string
411 412 importstrings for methods to serialize message parts. If just
412 413 'json' or 'pickle', predefined JSON and pickle packers will be used.
413 414 Otherwise, the entire importstring must be used.
414 415
415 416 The functions must accept at least valid JSON input, and output
416 417 *bytes*.
417 418
418 419 For example, to use msgpack:
419 420 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
420 421 pack/unpack : callables
421 422 You can also set the pack/unpack callables for serialization
422 423 directly.
423 424 session : unicode (must be ascii)
424 425 the ID of this Session object. The default is to generate a new
425 426 UUID.
426 427 bsession : bytes
427 428 The session as bytes
428 429 username : unicode
429 430 username added to message headers. The default is to ask the OS.
430 431 key : bytes
431 432 The key used to initialize an HMAC signature. If unset, messages
432 433 will not be signed or checked.
433 434 signature_scheme : str
434 435 The message digest scheme. Currently must be of the form 'hmac-HASH',
435 436 where 'HASH' is a hashing function available in Python's hashlib.
436 437 The default is 'hmac-sha256'.
437 438 This is ignored if 'key' is empty.
438 439 keyfile : filepath
439 440 The file containing a key. If this is set, `key` will be
440 441 initialized to the contents of the file.
441 442 """
442 443 super(Session, self).__init__(**kwargs)
443 444 self._check_packers()
444 445 self.none = self.pack({})
445 446 # ensure self._session_default() if necessary, so bsession is defined:
446 447 self.session
447 448 self.pid = os.getpid()
448 449 self._new_auth()
449 450
450 451 @property
451 452 def msg_id(self):
452 453 """always return new uuid"""
453 454 return str(uuid.uuid4())
454 455
455 456 def _check_packers(self):
456 457 """check packers for datetime support."""
457 458 pack = self.pack
458 459 unpack = self.unpack
459 460
460 461 # check simple serialization
461 462 msg = dict(a=[1,'hi'])
462 463 try:
463 464 packed = pack(msg)
464 465 except Exception as e:
465 466 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
466 467 if self.packer == 'json':
467 468 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
468 469 else:
469 470 jsonmsg = ""
470 471 raise ValueError(
471 472 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
472 473 )
473 474
474 475 # ensure packed message is bytes
475 476 if not isinstance(packed, bytes):
476 477 raise ValueError("message packed to %r, but bytes are required"%type(packed))
477 478
478 479 # check that unpack is pack's inverse
479 480 try:
480 481 unpacked = unpack(packed)
481 482 assert unpacked == msg
482 483 except Exception as e:
483 484 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
484 485 if self.packer == 'json':
485 486 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
486 487 else:
487 488 jsonmsg = ""
488 489 raise ValueError(
489 490 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
490 491 )
491 492
492 493 # check datetime support
493 494 msg = dict(t=datetime.now())
494 495 try:
495 496 unpacked = unpack(pack(msg))
496 497 if isinstance(unpacked['t'], datetime):
497 498 raise ValueError("Shouldn't deserialize to datetime")
498 499 except Exception:
499 500 self.pack = lambda o: pack(squash_dates(o))
500 501 self.unpack = lambda s: unpack(s)
501 502
502 503 def msg_header(self, msg_type):
503 504 return msg_header(self.msg_id, msg_type, self.username, self.session)
504 505
505 506 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
506 507 """Return the nested message dict.
507 508
508 509 This format is different from what is sent over the wire. The
509 510 serialize/deserialize methods converts this nested message dict to the wire
510 511 format, which is a list of message parts.
511 512 """
512 513 msg = {}
513 514 header = self.msg_header(msg_type) if header is None else header
514 515 msg['header'] = header
515 516 msg['msg_id'] = header['msg_id']
516 517 msg['msg_type'] = header['msg_type']
517 518 msg['parent_header'] = {} if parent is None else extract_header(parent)
518 519 msg['content'] = {} if content is None else content
519 520 msg['metadata'] = self.metadata.copy()
520 521 if metadata is not None:
521 522 msg['metadata'].update(metadata)
522 523 return msg
523 524
524 525 def sign(self, msg_list):
525 526 """Sign a message with HMAC digest. If no auth, return b''.
526 527
527 528 Parameters
528 529 ----------
529 530 msg_list : list
530 531 The [p_header,p_parent,p_content] part of the message list.
531 532 """
532 533 if self.auth is None:
533 534 return b''
534 535 h = self.auth.copy()
535 536 for m in msg_list:
536 537 h.update(m)
537 538 return str_to_bytes(h.hexdigest())
538 539
539 540 def serialize(self, msg, ident=None):
540 541 """Serialize the message components to bytes.
541 542
542 543 This is roughly the inverse of deserialize. The serialize/deserialize
543 544 methods work with full message lists, whereas pack/unpack work with
544 545 the individual message parts in the message list.
545 546
546 547 Parameters
547 548 ----------
548 549 msg : dict or Message
549 550 The next message dict as returned by the self.msg method.
550 551
551 552 Returns
552 553 -------
553 554 msg_list : list
554 555 The list of bytes objects to be sent with the format::
555 556
556 557 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
557 558 p_metadata, p_content, buffer1, buffer2, ...]
558 559
559 560 In this list, the ``p_*`` entities are the packed or serialized
560 561 versions, so if JSON is used, these are utf8 encoded JSON strings.
561 562 """
562 563 content = msg.get('content', {})
563 564 if content is None:
564 565 content = self.none
565 566 elif isinstance(content, dict):
566 567 content = self.pack(content)
567 568 elif isinstance(content, bytes):
568 569 # content is already packed, as in a relayed message
569 570 pass
570 571 elif isinstance(content, unicode_type):
571 572 # should be bytes, but JSON often spits out unicode
572 573 content = content.encode('utf8')
573 574 else:
574 575 raise TypeError("Content incorrect type: %s"%type(content))
575 576
576 577 real_message = [self.pack(msg['header']),
577 578 self.pack(msg['parent_header']),
578 579 self.pack(msg['metadata']),
579 580 content,
580 581 ]
581 582
582 583 to_send = []
583 584
584 585 if isinstance(ident, list):
585 586 # accept list of idents
586 587 to_send.extend(ident)
587 588 elif ident is not None:
588 589 to_send.append(ident)
589 590 to_send.append(DELIM)
590 591
591 592 signature = self.sign(real_message)
592 593 to_send.append(signature)
593 594
594 595 to_send.extend(real_message)
595 596
596 597 return to_send
597 598
598 599 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
599 600 buffers=None, track=False, header=None, metadata=None):
600 601 """Build and send a message via stream or socket.
601 602
602 603 The message format used by this function internally is as follows:
603 604
604 605 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
605 606 buffer1,buffer2,...]
606 607
607 608 The serialize/deserialize methods convert the nested message dict into this
608 609 format.
609 610
610 611 Parameters
611 612 ----------
612 613
613 614 stream : zmq.Socket or ZMQStream
614 615 The socket-like object used to send the data.
615 616 msg_or_type : str or Message/dict
616 617 Normally, msg_or_type will be a msg_type unless a message is being
617 618 sent more than once. If a header is supplied, this can be set to
618 619 None and the msg_type will be pulled from the header.
619 620
620 621 content : dict or None
621 622 The content of the message (ignored if msg_or_type is a message).
622 623 header : dict or None
623 624 The header dict for the message (ignored if msg_to_type is a message).
624 625 parent : Message or dict or None
625 626 The parent or parent header describing the parent of this message
626 627 (ignored if msg_or_type is a message).
627 628 ident : bytes or list of bytes
628 629 The zmq.IDENTITY routing path.
629 630 metadata : dict or None
630 631 The metadata describing the message
631 632 buffers : list or None
632 633 The already-serialized buffers to be appended to the message.
633 634 track : bool
634 635 Whether to track. Only for use with Sockets, because ZMQStream
635 636 objects cannot track messages.
636 637
637 638
638 639 Returns
639 640 -------
640 641 msg : dict
641 642 The constructed message.
642 643 """
643 644 if not isinstance(stream, zmq.Socket):
644 645 # ZMQStreams and dummy sockets do not support tracking.
645 646 track = False
646 647
647 648 if isinstance(msg_or_type, (Message, dict)):
648 649 # We got a Message or message dict, not a msg_type so don't
649 650 # build a new Message.
650 651 msg = msg_or_type
651 652 buffers = buffers or msg.get('buffers', [])
652 653 else:
653 654 msg = self.msg(msg_or_type, content=content, parent=parent,
654 655 header=header, metadata=metadata)
655 656 if not os.getpid() == self.pid:
656 io.rprint("WARNING: attempted to send message from fork")
657 io.rprint(msg)
657 get_logger().warn("WARNING: attempted to send message from fork\n%s",
658 msg
659 )
658 660 return
659 661 buffers = [] if buffers is None else buffers
660 662 if self.adapt_version:
661 663 msg = adapt(msg, self.adapt_version)
662 664 to_send = self.serialize(msg, ident)
663 665 to_send.extend(buffers)
664 666 longest = max([ len(s) for s in to_send ])
665 667 copy = (longest < self.copy_threshold)
666 668
667 669 if buffers and track and not copy:
668 670 # only really track when we are doing zero-copy buffers
669 671 tracker = stream.send_multipart(to_send, copy=False, track=True)
670 672 else:
671 673 # use dummy tracker, which will be done immediately
672 674 tracker = DONE
673 675 stream.send_multipart(to_send, copy=copy)
674 676
675 677 if self.debug:
676 678 pprint.pprint(msg)
677 679 pprint.pprint(to_send)
678 680 pprint.pprint(buffers)
679 681
680 682 msg['tracker'] = tracker
681 683
682 684 return msg
683 685
684 686 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
685 687 """Send a raw message via ident path.
686 688
687 689 This method is used to send a already serialized message.
688 690
689 691 Parameters
690 692 ----------
691 693 stream : ZMQStream or Socket
692 694 The ZMQ stream or socket to use for sending the message.
693 695 msg_list : list
694 696 The serialized list of messages to send. This only includes the
695 697 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
696 698 the message.
697 699 ident : ident or list
698 700 A single ident or a list of idents to use in sending.
699 701 """
700 702 to_send = []
701 703 if isinstance(ident, bytes):
702 704 ident = [ident]
703 705 if ident is not None:
704 706 to_send.extend(ident)
705 707
706 708 to_send.append(DELIM)
707 709 to_send.append(self.sign(msg_list))
708 710 to_send.extend(msg_list)
709 711 stream.send_multipart(to_send, flags, copy=copy)
710 712
711 713 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
712 714 """Receive and unpack a message.
713 715
714 716 Parameters
715 717 ----------
716 718 socket : ZMQStream or Socket
717 719 The socket or stream to use in receiving.
718 720
719 721 Returns
720 722 -------
721 723 [idents], msg
722 724 [idents] is a list of idents and msg is a nested message dict of
723 725 same format as self.msg returns.
724 726 """
725 727 if isinstance(socket, ZMQStream):
726 728 socket = socket.socket
727 729 try:
728 730 msg_list = socket.recv_multipart(mode, copy=copy)
729 731 except zmq.ZMQError as e:
730 732 if e.errno == zmq.EAGAIN:
731 733 # We can convert EAGAIN to None as we know in this case
732 734 # recv_multipart won't return None.
733 735 return None,None
734 736 else:
735 737 raise
736 738 # split multipart message into identity list and message dict
737 739 # invalid large messages can cause very expensive string comparisons
738 740 idents, msg_list = self.feed_identities(msg_list, copy)
739 741 try:
740 742 return idents, self.deserialize(msg_list, content=content, copy=copy)
741 743 except Exception as e:
742 744 # TODO: handle it
743 745 raise e
744 746
745 747 def feed_identities(self, msg_list, copy=True):
746 748 """Split the identities from the rest of the message.
747 749
748 750 Feed until DELIM is reached, then return the prefix as idents and
749 751 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
750 752 but that would be silly.
751 753
752 754 Parameters
753 755 ----------
754 756 msg_list : a list of Message or bytes objects
755 757 The message to be split.
756 758 copy : bool
757 759 flag determining whether the arguments are bytes or Messages
758 760
759 761 Returns
760 762 -------
761 763 (idents, msg_list) : two lists
762 764 idents will always be a list of bytes, each of which is a ZMQ
763 765 identity. msg_list will be a list of bytes or zmq.Messages of the
764 766 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
765 767 should be unpackable/unserializable via self.deserialize at this
766 768 point.
767 769 """
768 770 if copy:
769 771 idx = msg_list.index(DELIM)
770 772 return msg_list[:idx], msg_list[idx+1:]
771 773 else:
772 774 failed = True
773 775 for idx,m in enumerate(msg_list):
774 776 if m.bytes == DELIM:
775 777 failed = False
776 778 break
777 779 if failed:
778 780 raise ValueError("DELIM not in msg_list")
779 781 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
780 782 return [m.bytes for m in idents], msg_list
781 783
782 784 def _add_digest(self, signature):
783 785 """add a digest to history to protect against replay attacks"""
784 786 if self.digest_history_size == 0:
785 787 # no history, never add digests
786 788 return
787 789
788 790 self.digest_history.add(signature)
789 791 if len(self.digest_history) > self.digest_history_size:
790 792 # threshold reached, cull 10%
791 793 self._cull_digest_history()
792 794
793 795 def _cull_digest_history(self):
794 796 """cull the digest history
795 797
796 798 Removes a randomly selected 10% of the digest history
797 799 """
798 800 current = len(self.digest_history)
799 801 n_to_cull = max(int(current // 10), current - self.digest_history_size)
800 802 if n_to_cull >= current:
801 803 self.digest_history = set()
802 804 return
803 805 to_cull = random.sample(self.digest_history, n_to_cull)
804 806 self.digest_history.difference_update(to_cull)
805 807
806 808 def deserialize(self, msg_list, content=True, copy=True):
807 809 """Unserialize a msg_list to a nested message dict.
808 810
809 811 This is roughly the inverse of serialize. The serialize/deserialize
810 812 methods work with full message lists, whereas pack/unpack work with
811 813 the individual message parts in the message list.
812 814
813 815 Parameters
814 816 ----------
815 817 msg_list : list of bytes or Message objects
816 818 The list of message parts of the form [HMAC,p_header,p_parent,
817 819 p_metadata,p_content,buffer1,buffer2,...].
818 820 content : bool (True)
819 821 Whether to unpack the content dict (True), or leave it packed
820 822 (False).
821 823 copy : bool (True)
822 824 Whether msg_list contains bytes (True) or the non-copying Message
823 825 objects in each place (False).
824 826
825 827 Returns
826 828 -------
827 829 msg : dict
828 830 The nested message dict with top-level keys [header, parent_header,
829 831 content, buffers]. The buffers are returned as memoryviews.
830 832 """
831 833 minlen = 5
832 834 message = {}
833 835 if not copy:
834 836 # pyzmq didn't copy the first parts of the message, so we'll do it
835 837 for i in range(minlen):
836 838 msg_list[i] = msg_list[i].bytes
837 839 if self.auth is not None:
838 840 signature = msg_list[0]
839 841 if not signature:
840 842 raise ValueError("Unsigned Message")
841 843 if signature in self.digest_history:
842 844 raise ValueError("Duplicate Signature: %r" % signature)
843 845 self._add_digest(signature)
844 846 check = self.sign(msg_list[1:5])
845 847 if not compare_digest(signature, check):
846 848 raise ValueError("Invalid Signature: %r" % signature)
847 849 if not len(msg_list) >= minlen:
848 850 raise TypeError("malformed message, must have at least %i elements"%minlen)
849 851 header = self.unpack(msg_list[1])
850 852 message['header'] = extract_dates(header)
851 853 message['msg_id'] = header['msg_id']
852 854 message['msg_type'] = header['msg_type']
853 855 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
854 856 message['metadata'] = self.unpack(msg_list[3])
855 857 if content:
856 858 message['content'] = self.unpack(msg_list[4])
857 859 else:
858 860 message['content'] = msg_list[4]
859 861 buffers = [memoryview(b) for b in msg_list[5:]]
860 862 if buffers and buffers[0].shape is None:
861 863 # force copy to workaround pyzmq #646
862 864 buffers = [memoryview(b.bytes) for b in msg_list[5:]]
863 865 message['buffers'] = buffers
864 866 # adapt to the current version
865 867 return adapt(message)
866 868
867 869 def unserialize(self, *args, **kwargs):
868 870 warnings.warn(
869 871 "Session.unserialize is deprecated. Use Session.deserialize.",
870 872 DeprecationWarning,
871 873 )
872 874 return self.deserialize(*args, **kwargs)
873 875
874 876
875 877 def test_msg2obj():
876 878 am = dict(x=1)
877 879 ao = Message(am)
878 880 assert ao.x == am['x']
879 881
880 882 am['y'] = dict(z=1)
881 883 ao = Message(am)
882 884 assert ao.y.z == am['y']['z']
883 885
884 886 k1, k2 = 'y', 'z'
885 887 assert ao[k1][k2] == am[k1][k2]
886 888
887 889 am2 = dict(ao)
888 890 assert am['x'] == am2['x']
889 891 assert am['y']['z'] == am2['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now