##// END OF EJS Templates
Extract session buffers as memoryviews...
Jason Grout -
Show More
@@ -1,873 +1,876 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 import warnings
22 22 from datetime import datetime
23 23
24 24 try:
25 25 import cPickle
26 26 pickle = cPickle
27 27 except:
28 28 cPickle = None
29 29 import pickle
30 30
31 31 try:
32 32 # We are using compare_digest to limit the surface of timing attacks
33 33 from hmac import compare_digest
34 34 except ImportError:
35 35 # Python < 2.7.7: When digests don't match no feedback is provided,
36 36 # limiting the surface of attack
37 37 def compare_digest(a,b): return a == b
38 38
39 39 import zmq
40 40 from zmq.utils import jsonapi
41 41 from zmq.eventloop.ioloop import IOLoop
42 42 from zmq.eventloop.zmqstream import ZMQStream
43 43
44 44 from IPython.core.release import kernel_protocol_version
45 45 from IPython.config.configurable import Configurable, LoggingConfigurable
46 46 from IPython.utils import io
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 50 iteritems)
51 51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 52 DottedObjectName, CUnicode, Dict, Integer,
53 53 TraitError,
54 54 )
55 55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 56 from IPython.kernel.adapter import adapt
57 57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # utility functions
61 61 #-----------------------------------------------------------------------------
62 62
def squash_unicode(obj):
    """Coerce unicode back to bytestrings, recursing into dicts and lists.

    Dicts and lists are modified in place and returned; a unicode string is
    returned as its UTF-8 encoding. Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: the loop body pops unicode
        # keys and re-inserts them encoded, and mutating a dict while
        # iterating over its live key view raises RuntimeError on Python 3.
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # globals and defaults
79 79 #-----------------------------------------------------------------------------
80 80
81 81 # ISO8601-ify datetime objects
82 82 # allow unicode
83 83 # disallow nan, because it's not actually valid JSON
84 84 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
85 85 ensure_ascii=False, allow_nan=False,
86 86 )
87 87 json_unpacker = lambda s: jsonapi.loads(s)
88 88
89 89 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
90 90 pickle_unpacker = pickle.loads
91 91
92 92 default_packer = json_packer
93 93 default_unpacker = json_unpacker
94 94
95 95 DELIM = b"<IDS|MSG>"
96 96 # singleton dummy tracker, which will always report as done
97 97 DONE = zmq.MessageTracker()
98 98
99 99 #-----------------------------------------------------------------------------
100 100 # Mixin tools for apps that use Sessions
101 101 #-----------------------------------------------------------------------------
102 102
103 103 session_aliases = dict(
104 104 ident = 'Session.session',
105 105 user = 'Session.username',
106 106 keyfile = 'Session.keyfile',
107 107 )
108 108
109 109 session_flags = {
110 110 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
111 111 'keyfile' : '' }},
112 112 """Use HMAC digests for authentication of messages.
113 113 Setting this flag will generate a new UUID to use as the HMAC key.
114 114 """),
115 115 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
116 116 """Don't authenticate messages."""),
117 117 }
118 118
def default_secure(cfg):
    """Make a config environment secure by default.

    When neither Session.key nor Session.keyfile has been set in `cfg`,
    Session.key is set to a freshly generated random UUID (as bytes).
    An existing key or keyfile setting is left untouched.
    """
    already_configured = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if not already_configured:
        # key/keyfile not specified, generate new UUID:
        cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
131 131
132 132
133 133 #-----------------------------------------------------------------------------
134 134 # Classes
135 135 #-----------------------------------------------------------------------------
136 136
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that own a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        # swap in the logger registered under the new name
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    session = Instance('IPython.kernel.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        # no session was supplied: build one from the same keyword arguments
        if self.session is None:
            self.session = Session(**kwargs)
163 163
164 164
class Message(object):
    """Attribute-style view of a message dict.

    Every key of the source dict becomes an instance attribute, and nested
    dict values are wrapped in Message recursively. Calling dict(msg_obj)
    turns the top level back into a dict.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in iteritems(dict(msg_dict)):
            attrs[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # Yielding (key, value) pairs is what makes dict(msg_obj) work.
        return iter(iteritems(self.__dict__))

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
193 193
194 194
def msg_header(msg_id, msg_type, username, session):
    """Return a header dict for a new message.

    The dict holds the four arguments under their own names, plus 'date'
    (the current local time) and 'version' (kernel_protocol_version).
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
199 199
def extract_header(msg_or_header):
    """Given a message or a bare header, return the header as a dict.

    A falsy argument yields an empty dict. A KeyError propagates when the
    argument has neither a 'header' nor a 'msg_id' key.
    """
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        header = msg_or_header['header']
    except KeyError:
        # Otherwise it has to be just the header; probing for 'msg_id'
        # raises KeyError when it is not one.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
218 218
219 219 class Session(Configurable):
220 220 """Object for handling serialization and sending of messages.
221 221
222 222 The Session object handles building messages and sending them
223 223 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
224 224 other over the network via Session objects, and only need to work with the
225 225 dict-based IPython message spec. The Session will handle
226 226 serialization/deserialization, security, and metadata.
227 227
228 228 Sessions support configurable serialization via packer/unpacker traits,
229 229 and signing with HMAC digests via the key/keyfile traits.
230 230
231 231 Parameters
232 232 ----------
233 233
234 234 debug : bool
235 235 whether to trigger extra debugging statements
236 236 packer/unpacker : str : 'json', 'pickle' or import_string
237 237 importstrings for methods to serialize message parts. If just
238 238 'json' or 'pickle', predefined JSON and pickle packers will be used.
239 239 Otherwise, the entire importstring must be used.
240 240
241 241 The functions must accept at least valid JSON input, and output *bytes*.
242 242
243 243 For example, to use msgpack:
244 244 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
245 245 pack/unpack : callables
246 246 You can also set the pack/unpack callables for serialization directly.
247 247 session : bytes
248 248 the ID of this Session object. The default is to generate a new UUID.
249 249 username : unicode
250 250 username added to message headers. The default is to ask the OS.
251 251 key : bytes
252 252 The key used to initialize an HMAC signature. If unset, messages
253 253 will not be signed or checked.
254 254 keyfile : filepath
255 255 The file containing a key. If this is set, `key` will be initialized
256 256 to the contents of the file.
257 257
258 258 """
259 259
260 260 debug=Bool(False, config=True, help="""Debug output in the Session""")
261 261
262 262 packer = DottedObjectName('json',config=True,
263 263 help="""The name of the packer for serializing messages.
264 264 Should be one of 'json', 'pickle', or an import name
265 265 for a custom callable serializer.""")
266 266 def _packer_changed(self, name, old, new):
267 267 if new.lower() == 'json':
268 268 self.pack = json_packer
269 269 self.unpack = json_unpacker
270 270 self.unpacker = new
271 271 elif new.lower() == 'pickle':
272 272 self.pack = pickle_packer
273 273 self.unpack = pickle_unpacker
274 274 self.unpacker = new
275 275 else:
276 276 self.pack = import_item(str(new))
277 277
278 278 unpacker = DottedObjectName('json', config=True,
279 279 help="""The name of the unpacker for unserializing messages.
280 280 Only used with custom functions for `packer`.""")
281 281 def _unpacker_changed(self, name, old, new):
282 282 if new.lower() == 'json':
283 283 self.pack = json_packer
284 284 self.unpack = json_unpacker
285 285 self.packer = new
286 286 elif new.lower() == 'pickle':
287 287 self.pack = pickle_packer
288 288 self.unpack = pickle_unpacker
289 289 self.packer = new
290 290 else:
291 291 self.unpack = import_item(str(new))
292 292
293 293 session = CUnicode(u'', config=True,
294 294 help="""The UUID identifying this session.""")
295 295 def _session_default(self):
296 296 u = unicode_type(uuid.uuid4())
297 297 self.bsession = u.encode('ascii')
298 298 return u
299 299
300 300 def _session_changed(self, name, old, new):
301 301 self.bsession = self.session.encode('ascii')
302 302
303 303 # bsession is the session as bytes
304 304 bsession = CBytes(b'')
305 305
306 306 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
307 307 help="""Username for the Session. Default is your system username.""",
308 308 config=True)
309 309
310 310 metadata = Dict({}, config=True,
311 311 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
312 312
313 313 # if 0, no adapting to do.
314 314 adapt_version = Integer(0)
315 315
316 316 # message signature related traits:
317 317
318 318 key = CBytes(b'', config=True,
319 319 help="""execution key, for extra authentication.""")
320 320 def _key_changed(self):
321 321 self._new_auth()
322 322
323 323 signature_scheme = Unicode('hmac-sha256', config=True,
324 324 help="""The digest scheme used to construct the message signatures.
325 325 Must have the form 'hmac-HASH'.""")
326 326 def _signature_scheme_changed(self, name, old, new):
327 327 if not new.startswith('hmac-'):
328 328 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
329 329 hash_name = new.split('-', 1)[1]
330 330 try:
331 331 self.digest_mod = getattr(hashlib, hash_name)
332 332 except AttributeError:
333 333 raise TraitError("hashlib has no such attribute: %s" % hash_name)
334 334 self._new_auth()
335 335
336 336 digest_mod = Any()
337 337 def _digest_mod_default(self):
338 338 return hashlib.sha256
339 339
340 340 auth = Instance(hmac.HMAC)
341 341
342 342 def _new_auth(self):
343 343 if self.key:
344 344 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
345 345 else:
346 346 self.auth = None
347 347
348 348 digest_history = Set()
349 349 digest_history_size = Integer(2**16, config=True,
350 350 help="""The maximum number of digests to remember.
351 351
352 352 The digest history will be culled when it exceeds this value.
353 353 """
354 354 )
355 355
356 356 keyfile = Unicode('', config=True,
357 357 help="""path to file containing execution key.""")
358 358 def _keyfile_changed(self, name, old, new):
359 359 with open(new, 'rb') as f:
360 360 self.key = f.read().strip()
361 361
362 362 # for protecting against sends from forks
363 363 pid = Integer()
364 364
365 365 # serialization traits:
366 366
367 367 pack = Any(default_packer) # the actual packer function
368 368 def _pack_changed(self, name, old, new):
369 369 if not callable(new):
370 370 raise TypeError("packer must be callable, not %s"%type(new))
371 371
372 372 unpack = Any(default_unpacker) # the actual packer function
373 373 def _unpack_changed(self, name, old, new):
374 374 # unpacker is not checked - it is assumed to be
375 375 if not callable(new):
376 376 raise TypeError("unpacker must be callable, not %s"%type(new))
377 377
378 378 # thresholds:
379 379 copy_threshold = Integer(2**16, config=True,
380 380 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
381 381 buffer_threshold = Integer(MAX_BYTES, config=True,
382 382 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
383 383 item_threshold = Integer(MAX_ITEMS, config=True,
384 384 help="""The maximum number of items for a container to be introspected for custom serialization.
385 385 Containers larger than this are pickled outright.
386 386 """
387 387 )
388 388
389 389
390 390 def __init__(self, **kwargs):
391 391 """create a Session object
392 392
393 393 Parameters
394 394 ----------
395 395
396 396 debug : bool
397 397 whether to trigger extra debugging statements
398 398 packer/unpacker : str : 'json', 'pickle' or import_string
399 399 importstrings for methods to serialize message parts. If just
400 400 'json' or 'pickle', predefined JSON and pickle packers will be used.
401 401 Otherwise, the entire importstring must be used.
402 402
403 403 The functions must accept at least valid JSON input, and output
404 404 *bytes*.
405 405
406 406 For example, to use msgpack:
407 407 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
408 408 pack/unpack : callables
409 409 You can also set the pack/unpack callables for serialization
410 410 directly.
411 411 session : unicode (must be ascii)
412 412 the ID of this Session object. The default is to generate a new
413 413 UUID.
414 414 bsession : bytes
415 415 The session as bytes
416 416 username : unicode
417 417 username added to message headers. The default is to ask the OS.
418 418 key : bytes
419 419 The key used to initialize an HMAC signature. If unset, messages
420 420 will not be signed or checked.
421 421 signature_scheme : str
422 422 The message digest scheme. Currently must be of the form 'hmac-HASH',
423 423 where 'HASH' is a hashing function available in Python's hashlib.
424 424 The default is 'hmac-sha256'.
425 425 This is ignored if 'key' is empty.
426 426 keyfile : filepath
427 427 The file containing a key. If this is set, `key` will be
428 428 initialized to the contents of the file.
429 429 """
430 430 super(Session, self).__init__(**kwargs)
431 431 self._check_packers()
432 432 self.none = self.pack({})
433 433 # ensure self._session_default() if necessary, so bsession is defined:
434 434 self.session
435 435 self.pid = os.getpid()
436 436
437 437 @property
438 438 def msg_id(self):
439 439 """always return new uuid"""
440 440 return str(uuid.uuid4())
441 441
442 442 def _check_packers(self):
443 443 """check packers for datetime support."""
444 444 pack = self.pack
445 445 unpack = self.unpack
446 446
447 447 # check simple serialization
448 448 msg = dict(a=[1,'hi'])
449 449 try:
450 450 packed = pack(msg)
451 451 except Exception as e:
452 452 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
453 453 if self.packer == 'json':
454 454 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
455 455 else:
456 456 jsonmsg = ""
457 457 raise ValueError(
458 458 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
459 459 )
460 460
461 461 # ensure packed message is bytes
462 462 if not isinstance(packed, bytes):
463 463 raise ValueError("message packed to %r, but bytes are required"%type(packed))
464 464
465 465 # check that unpack is pack's inverse
466 466 try:
467 467 unpacked = unpack(packed)
468 468 assert unpacked == msg
469 469 except Exception as e:
470 470 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
471 471 if self.packer == 'json':
472 472 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
473 473 else:
474 474 jsonmsg = ""
475 475 raise ValueError(
476 476 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
477 477 )
478 478
479 479 # check datetime support
480 480 msg = dict(t=datetime.now())
481 481 try:
482 482 unpacked = unpack(pack(msg))
483 483 if isinstance(unpacked['t'], datetime):
484 484 raise ValueError("Shouldn't deserialize to datetime")
485 485 except Exception:
486 486 self.pack = lambda o: pack(squash_dates(o))
487 487 self.unpack = lambda s: unpack(s)
488 488
489 489 def msg_header(self, msg_type):
490 490 return msg_header(self.msg_id, msg_type, self.username, self.session)
491 491
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
    """Build and return the nested message dict.

    This is not the wire format: serialize/deserialize convert between
    this nested dict and the list of message parts that is actually sent.

    A header is generated from `msg_type` unless one is passed in. The
    message metadata starts as a copy of `self.metadata` and is then
    updated with the `metadata` argument, if given.
    """
    if header is None:
        header = self.msg_header(msg_type)
    message = {
        'header': header,
        'msg_id': header['msg_id'],
        'msg_type': header['msg_type'],
        'parent_header': {} if parent is None else extract_header(parent),
        'content': {} if content is None else content,
        'metadata': self.metadata.copy(),
    }
    if metadata is not None:
        message['metadata'].update(metadata)
    return message
510 510
def sign(self, msg_list):
    """Return the HMAC hex digest of a message as bytes, or b'' without auth.

    Parameters
    ----------
    msg_list : list
        The packed message parts to sign, fed to the digest in order.
    """
    if self.auth is None:
        return b''
    # work on a copy so the keyed template in self.auth stays pristine
    digest = self.auth.copy()
    for part in msg_list:
        digest.update(part)
    return str_to_bytes(digest.hexdigest())
525 525
def serialize(self, msg, ident=None):
    """Serialize the message components to a list of bytes.

    Roughly the inverse of deserialize: serialize/deserialize handle whole
    message lists, while pack/unpack handle the individual parts.

    Parameters
    ----------
    msg : dict or Message
        The nested message dict as returned by the self.msg method.
    ident : bytes or list of bytes, optional
        Routing identity (or identities) placed before the delimiter.

    Returns
    -------
    msg_list : list
        The list of bytes objects to be sent, in the format::

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
             p_metadata, p_content]

        The ``p_*`` entries are the packed versions of the message parts.

    Raises
    ------
    TypeError
        If the content is not None, a dict, bytes, or unicode.
    """
    content = msg.get('content', {})
    if content is None:
        content = self.none
    elif isinstance(content, dict):
        content = self.pack(content)
    elif isinstance(content, unicode_type):
        # should be bytes, but JSON often spits out unicode
        content = content.encode('utf8')
    elif not isinstance(content, bytes):
        # bytes content is already packed (as in a relayed message) and is
        # passed through untouched; anything else is an error
        raise TypeError("Content incorrect type: %s"%type(content))

    real_message = [
        self.pack(msg[part])
        for part in ('header', 'parent_header', 'metadata')
    ]
    real_message.append(content)

    if isinstance(ident, list):
        # accept list of idents
        to_send = list(ident)
    elif ident is not None:
        to_send = [ident]
    else:
        to_send = []

    to_send.append(DELIM)
    to_send.append(self.sign(real_message))
    to_send.extend(real_message)
    return to_send
584 584
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, track=False, header=None, metadata=None):
    """Build and send a message via stream or socket.

    The message parts are sent in the following order:

       [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
        buffer1,buffer2,...]

    The serialize/deserialize methods convert the nested message dict into
    this format.

    Parameters
    ----------

    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally, msg_or_type will be a msg_type unless a message is being
        sent more than once. If a header is supplied, this can be set to
        None and the msg_type will be pulled from the header.

    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    header : dict or None
        The header dict for the message (ignored if msg_or_type is a message).
    parent : Message or dict or None
        The parent or parent header describing the parent of this message
        (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    metadata : dict or None
        The metadata describing the message
    buffers : list or None
        The already-serialized buffers to be appended to the message.
    track : bool
        Whether to track. Only for use with Sockets, because ZMQStream
        objects cannot track messages.


    Returns
    -------
    msg : dict
        The constructed message, with the tracker stored under 'tracker'.
        None is returned (and nothing is sent) when called from a process
        other than the one that created this Session.
    """
    if not isinstance(stream, zmq.Socket):
        # ZMQStreams and dummy sockets do not support tracking.
        track = False

    if isinstance(msg_or_type, (Message, dict)):
        # Already a message (or message dict): reuse it rather than
        # building a new one.
        msg = msg_or_type
        buffers = buffers or msg.get('buffers', [])
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       header=header, metadata=metadata)

    if os.getpid() != self.pid:
        # refuse to send from a forked process
        io.rprint("WARNING: attempted to send message from fork")
        io.rprint(msg)
        return

    if buffers is None:
        buffers = []
    if self.adapt_version:
        msg = adapt(msg, self.adapt_version)

    to_send = self.serialize(msg, ident)
    to_send.extend(buffers)
    # copy only when every part is below the copy threshold
    copy = max(len(part) for part in to_send) < self.copy_threshold

    if buffers and track and not copy:
        # only really track when we are doing zero-copy buffers
        tracker = stream.send_multipart(to_send, copy=False, track=True)
    else:
        # use dummy tracker, which will be done immediately
        tracker = DONE
        stream.send_multipart(to_send, copy=copy)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
670 670
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Send an already-serialized message via an ident path.

    The delimiter and a signature computed over `msg_list` are inserted
    between the idents and the message parts.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized list of messages to send. This only includes the
        [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion
        of the message.
    flags : int
        Passed through to ``stream.send_multipart``.
    copy : bool
        Passed through to ``stream.send_multipart``.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    if isinstance(ident, bytes):
        ident = [ident]
    to_send = [] if ident is None else list(ident)
    to_send += [DELIM, self.sign(msg_list)]
    to_send.extend(msg_list)
    stream.send_multipart(to_send, flags, copy=copy)
697 697
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving. For a ZMQStream, the
        underlying socket is used.
    mode : int
        Flags passed to ``recv_multipart``.
    content : bool
        Passed to deserialize: whether to unpack the content.
    copy : bool
        Passed to ``recv_multipart``, feed_identities and deserialize.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns. ``(None, None)`` is returned when
        the receive fails with EAGAIN. Any other ZMQError, and any error
        from splitting or deserializing the message, propagates.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode, copy=copy)
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            # We can convert EAGAIN to None as we know in this case
            # recv_multipart won't return None.
            return None, None
        raise
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    # Deserialization errors are deliberately left to propagate with their
    # original traceback (a catch-and-`raise e` would discard it on
    # Python 2).
    return idents, self.deserialize(msg_list, content=content, copy=copy)
731 731
def feed_identities(self, msg_list, copy=True):
    """Split the identities from the rest of the message.

    Everything before the first DELIM part is returned as the idents and
    everything after it as the message list. An ident equal to DELIM would
    therefore break the split.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents will always be a list of bytes, each of which is a ZMQ
        identity. msg_list will be a list of bytes or zmq.Messages of the
        form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,...] and
        should be unpackable/unserializable via self.deserialize at this
        point.

    Raises
    ------
    ValueError
        If DELIM does not occur in msg_list.
    """
    if copy:
        split = msg_list.index(DELIM)
        return msg_list[:split], msg_list[split+1:]

    # non-copying frames: compare their .bytes against the delimiter
    for split, frame in enumerate(msg_list):
        if frame.bytes == DELIM:
            break
    else:
        raise ValueError("DELIM not in msg_list")
    idents = [frame.bytes for frame in msg_list[:split]]
    return idents, msg_list[split+1:]
768 768
769 769 def _add_digest(self, signature):
770 770 """add a digest to history to protect against replay attacks"""
771 771 if self.digest_history_size == 0:
772 772 # no history, never add digests
773 773 return
774 774
775 775 self.digest_history.add(signature)
776 776 if len(self.digest_history) > self.digest_history_size:
777 777 # threshold reached, cull 10%
778 778 self._cull_digest_history()
779 779
780 780 def _cull_digest_history(self):
781 781 """cull the digest history
782 782
783 783 Removes a randomly selected 10% of the digest history
784 784 """
785 785 current = len(self.digest_history)
786 786 n_to_cull = max(int(current // 10), current - self.digest_history_size)
787 787 if n_to_cull >= current:
788 788 self.digest_history = set()
789 789 return
790 790 to_cull = random.sample(self.digest_history, n_to_cull)
791 791 self.digest_history.difference_update(to_cull)
792 792
def deserialize(self, msg_list, content=True, copy=True):
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/deserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_metadata,p_content,buffer1,buffer2,...].
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether msg_list holds bytes (True) or non-copying Message
        objects (False). In the latter case the first five parts are
        replaced in place by their ``.bytes``.

    Returns
    -------
    msg : dict
        The nested message dict with top-level keys [header, msg_id,
        msg_type, parent_header, metadata, content, buffers], passed
        through ``adapt``. The buffers are memoryviews.

    Raises
    ------
    TypeError
        If msg_list has fewer than five parts.
    ValueError
        If signing is enabled and the message is unsigned, carries a
        signature that was already seen, or carries an invalid signature.
    """
    minlen = 5
    # Validate the length before touching msg_list[0..4]; otherwise a short
    # message surfaces as an IndexError instead of the intended TypeError.
    if len(msg_list) < minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    message = {}
    if not copy:
        for i in range(minlen):
            msg_list[i] = msg_list[i].bytes
    if self.auth is not None:
        signature = msg_list[0]
        if not signature:
            raise ValueError("Unsigned Message")
        if signature in self.digest_history:
            raise ValueError("Duplicate Signature: %r" % signature)
        check = self.sign(msg_list[1:5])
        if not compare_digest(signature, check):
            raise ValueError("Invalid Signature: %r" % signature)
        # Record the signature only once it has been verified, so forged
        # messages cannot fill (and trigger culling of) the replay history.
        self._add_digest(signature)
    header = self.unpack(msg_list[1])
    message['header'] = extract_dates(header)
    message['msg_id'] = header['msg_id']
    message['msg_type'] = header['msg_type']
    message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
    message['metadata'] = self.unpack(msg_list[3])
    if content:
        message['content'] = self.unpack(msg_list[4])
    else:
        message['content'] = msg_list[4]
    buffers = [memoryview(b) for b in msg_list[5:]]
    if buffers and buffers[0].shape is None:
        # force copy to workaround pyzmq #646
        buffers = [memoryview(b.bytes) for b in msg_list[5:]]
    message['buffers'] = buffers
    # adapt to the current version
    return adapt(message)
849 852
def unserialize(self, *args, **kwargs):
    """Deprecated alias of deserialize.

    Emits a DeprecationWarning, then forwards all arguments to
    ``self.deserialize`` and returns its result.
    """
    message = "Session.unserialize is deprecated. Use Session.deserialize."
    warnings.warn(message, DeprecationWarning)
    return self.deserialize(*args, **kwargs)
856 859
857 860
def test_msg2obj():
    """Exercise Message: attribute access, item access, and dict() round trip."""
    am = {'x': 1}
    ao = Message(am)
    assert ao.x == am['x']

    # nested dicts become nested Message objects
    am['y'] = {'z': 1}
    ao = Message(am)
    assert ao.y.z == am['y']['z']

    # item access mirrors attribute access
    outer, inner = 'y', 'z'
    assert ao[outer][inner] == am[outer][inner]

    # dict(msg_obj) rebuilds the top-level mapping
    am2 = dict(ao)
    assert am['x'] == am2['x']
    assert am['y']['z'] == am2['y']['z']
873 876
General Comments 0
You need to be logged in to leave comments. Login now