##// END OF EJS Templates
minor cleanup
Paul Ivanov -
Show More
@@ -1,864 +1,864
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
23 23 try:
24 24 import cPickle
25 25 pickle = cPickle
26 26 except:
27 27 cPickle = None
28 28 import pickle
29 29
30 30 try:
31 31 from hmac import compare_digest
32 32 except ImportError:
33 33 # Python < 2.7.7
34 34 def compare_digest(a,b):
35 35 return a == b
36 36
37 37 import zmq
38 38 from zmq.utils import jsonapi
39 39 from zmq.eventloop.ioloop import IOLoop
40 40 from zmq.eventloop.zmqstream import ZMQStream
41 41
42 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
42 from IPython.core.release import kernel_protocol_version
43 43 from IPython.config.configurable import Configurable, LoggingConfigurable
44 44 from IPython.utils import io
45 45 from IPython.utils.importstring import import_item
46 46 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
47 47 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
48 48 iteritems)
49 49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
50 50 DottedObjectName, CUnicode, Dict, Integer,
51 51 TraitError,
52 52 )
53 53 from IPython.utils.pickleutil import PICKLE_PROTOCOL
54 54 from IPython.kernel.adapter import adapt
55 55 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
56 56
57 57 #-----------------------------------------------------------------------------
58 58 # utility functions
59 59 #-----------------------------------------------------------------------------
60 60
def squash_unicode(obj):
    """Coerce unicode back to bytestrings, recursing into dicts and lists.

    Dicts and lists are modified in place and returned. A unicode string is
    returned utf8-encoded. Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: re-keying below pops and
        # inserts entries, which must not happen while walking a live view
        # (Python 3 raises "dictionary keys changed during iteration").
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
74 74
75 75 #-----------------------------------------------------------------------------
76 76 # globals and defaults
77 77 #-----------------------------------------------------------------------------
78 78
def json_packer(obj):
    """Serialize *obj* with zmq's jsonapi.

    * datetime objects are handled by ``date_default`` (ISO8601-ified)
    * unicode is allowed (``ensure_ascii=False``)
    * nan is disallowed, because it's not actually valid JSON
    """
    return jsonapi.dumps(
        obj,
        default=date_default,
        ensure_ascii=False,
        allow_nan=False,
    )


def json_unpacker(s):
    """Deserialize a JSON message part with zmq's jsonapi."""
    return jsonapi.loads(s)


def pickle_packer(o):
    """Pickle *o* with PICKLE_PROTOCOL after passing it through squash_dates."""
    return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)


# NOTE(review): unpickling can execute arbitrary code; only use the pickle
# unpacker on messages from trusted (authenticated) peers.
pickle_unpacker = pickle.loads

# JSON is the default wire serialization
default_packer = json_packer
default_unpacker = json_unpacker

# frame separating the routing identities from the message proper
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
96 96
97 97 #-----------------------------------------------------------------------------
98 98 # Mixin tools for apps that use Sessions
99 99 #-----------------------------------------------------------------------------
100 100
# command-line aliases mapping short option names onto Session traits
session_aliases = {
    'ident': 'Session.session',
    'user': 'Session.username',
    'keyfile': 'Session.keyfile',
}

# command-line flags: name -> (config to apply, help text).
# The key for 'secure' is generated once, when this module is imported.
session_flags = {}
session_flags['secure'] = (
    {'Session': {'key': str_to_bytes(str(uuid.uuid4())), 'keyfile': ''}},
    """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
)
session_flags['no-secure'] = (
    {'Session': {'key': b'', 'keyfile': ''}},
    """Don't authenticate messages.""",
)
116 116
def default_secure(cfg):
    """Make a config environment secure by default.

    When neither ``Session.key`` nor ``Session.keyfile`` is present in
    *cfg*, ``Session.key`` is set to a freshly generated random UUID.
    A config that already names a key or keyfile is left untouched.
    """
    has_key_source = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if not has_key_source:
        # key/keyfile not specified, generate new UUID:
        cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
129 129
130 130
131 131 #-----------------------------------------------------------------------------
132 132 # Classes
133 133 #-----------------------------------------------------------------------------
134 134
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that carry a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    # name of the logger; changing it rebinds self.log
    logname = Unicode('')

    # not configurable:
    context = Instance('zmq.Context')
    session = Instance('IPython.kernel.zmq.session.Session')
    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        # no session was handed in: construct one from the same kwargs
        if self.session is None:
            self.session = Session(**kwargs)

    def _logname_changed(self, name, old, new):
        """Switch to the logger registered under the new name."""
        self.log = logging.getLogger(new)

    def _context_default(self):
        """Default to the global zmq Context instance."""
        return zmq.Context.instance()

    def _loop_default(self):
        """Default to the global IOLoop instance."""
        return IOLoop.instance()
161 161
162 162
class Message(object):
    """Thin wrapper exposing the keys of a message dict as attributes.

    Nested dicts are wrapped recursively. ``dict(msg_obj)`` turns a Message
    back into a dict (one level deep), because iterating a Message yields
    ``(key, value)`` pairs.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in dict(msg_dict).items():
            # wrap nested dicts so attribute access works at every level
            attrs[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # Yielding (key, value) pairs lets dict(msg_obj) work out of the box.
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
191 191
192 192
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    The header holds the four arguments under their own names, plus
    ``date`` (the current local time) and ``version`` (the kernel protocol
    version).
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
197 197
def extract_header(msg_or_header):
    """Return the header of a message, or the argument itself if it is
    already a header.

    A falsy argument yields an empty dict. An argument that has neither a
    'header' nor a 'msg_id' key raises KeyError. The result is always a
    dict: non-dict headers are converted with ``dict()``.
    """
    if not msg_or_header:
        return {}
    try:
        # A full message carries its header under 'header'.
        header = msg_or_header['header']
    except KeyError:
        # Otherwise it must itself be a header, which always has a msg_id;
        # this lookup raises KeyError when it is neither.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
216 216
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : unicode (must be ascii)
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName('json', config=True,
        help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")

    def _packer_changed(self, name, old, new):
        # 'json' and 'pickle' select a matched pack/unpack pair and keep the
        # unpacker name in sync; anything else is imported as a callable.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")

    def _unpacker_changed(self, name, old, new):
        # mirror image of _packer_changed
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")

    def _session_default(self):
        u = unicode_type(uuid.uuid4())
        # keep the bytes form in step with the generated default
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)

    # message signature related traits:

    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")

    def _key_changed(self):
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")

    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()

    digest_mod = Any()

    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        """(Re)build the HMAC template from the current key and digest.

        With an empty key, auth is None and messages are neither signed nor
        checked.
        """
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")

    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer)  # the actual packer function

    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function

    def _unpack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        # packed form of an empty dict, used by serialize() for None content
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        # remember the creating process, so send() can refuse to run in a fork
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """Check that pack/unpack work together, and add datetime support.

        Raises
        ------
        ValueError
            if the packer cannot serialize a simple message, does not
            produce bytes, or the unpacker does not invert the packer.
        """
        pack = self.pack
        unpack = self.unpack

        def jsonmod_hint():
            # name the JSON backend in error messages when JSON is in use
            if self.packer == 'json':
                return "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            return ""

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception as e:
            error = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            raise ValueError(
                error.format(packer=self.packer, e=e, jsonmsg=jsonmod_hint())
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            # explicit check rather than `assert`, which is stripped under -O
            if unpacked != msg:
                raise ValueError("unpacked message %r != original %r" % (unpacked, msg))
        except Exception as e:
            error = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            raise ValueError(
                error.format(packer=self.packer, unpacker=self.unpacker,
                             e=e, jsonmsg=jsonmod_hint())
            )

        # check datetime support: if the round trip fails, or hands back a
        # datetime object, wrap the packer so dates are squashed first
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        """Return a new header of the given msg_type for this session."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/unserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        # start from a copy of the session-wide defaults, then overlay
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_metadata,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        # copy the template so the shared HMAC object is never mutated
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.

        Raises
        ------
        TypeError
            if the content is not None, a dict, bytes, or unicode.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [
            self.pack(msg['header']),
            self.pack(msg['parent_header']),
            self.pack(msg['metadata']),
            content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/unserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message, with the send tracker stored under
            'tracker'. None is returned (and nothing is sent) when called
            from a process other than the one that created the Session.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if os.getpid() != self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # copy only when every frame is below the zero-copy threshold
        longest = max(len(s) for s in to_send)
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns. (None, None) is returned when
            the receive fails with EAGAIN.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        # errors from unserialize propagate to the caller unchanged
        return idents, self.unserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
            and should be unpackable/unserializable via self.unserialize at this
            point.

        Raises
        ------
        ValueError
            if DELIM is not found in msg_list.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample needs a sequence: sets are rejected by newer Pythons
        to_cull = random.sample(tuple(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list holds bytes (True), or non-copying Message
            objects (False). In the latter case the first five parts are
            replaced in place by their bytes.

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, msg_id,
            msg_type, parent_header, metadata, content, buffers], passed
            through adapt().

        Raises
        ------
        TypeError
            if msg_list has fewer than five parts.
        ValueError
            if signing is enabled and the message is unsigned, repeats a
            signature already seen, or carries an invalid signature.
        """
        minlen = 5
        # Validate the length before touching individual parts, so that a
        # short message is reported as malformed instead of failing with an
        # IndexError or a misleading signature error.
        if len(msg_list) < minlen:
            raise TypeError("malformed message, must have at least %i elements" % minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        # adapt to the current version
        return adapt(message)
848 848
def test_msg2obj():
    """Check attribute access, item access and dict() round-trip of Message."""
    source = {'x': 1}
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts become nested Messages
    source['y'] = {'z': 1}
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    # item access mirrors attribute access
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # dict(msg_obj) recovers the top-level keys
    roundtrip = dict(wrapped)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
864 864
General Comments 0
You need to be logged in to leave comments. Login now