##// END OF EJS Templates
update comment...
Kyle Kelley -
Show More
@@ -1,868 +1,865 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
23 23 try:
24 24 import cPickle
25 25 pickle = cPickle
26 26 except:
27 27 cPickle = None
28 28 import pickle
29 29
30 30 try:
31 # We are using compare_digest to limit the surface of timing attacks
31 32 from hmac import compare_digest
32 33 except ImportError:
33 # Python < 2.7.7
34 import warnings
35 warnings.warn("You are using Python older than 2.7.7, please consider "
36 "updating to the latest version as it reduces a possible security"
37 " vulnerability.")
38 def compare_digest(a,b):
39 return a == b
34 # Python < 2.7.7: When digests don't match no feedback is provided,
35 # limiting the surface of attack
36 def compare_digest(a,b): return a == b
40 37
41 38 import zmq
42 39 from zmq.utils import jsonapi
43 40 from zmq.eventloop.ioloop import IOLoop
44 41 from zmq.eventloop.zmqstream import ZMQStream
45 42
46 43 from IPython.core.release import kernel_protocol_version
47 44 from IPython.config.configurable import Configurable, LoggingConfigurable
48 45 from IPython.utils import io
49 46 from IPython.utils.importstring import import_item
50 47 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 48 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
52 49 iteritems)
53 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
54 51 DottedObjectName, CUnicode, Dict, Integer,
55 52 TraitError,
56 53 )
57 54 from IPython.utils.pickleutil import PICKLE_PROTOCOL
58 55 from IPython.kernel.adapter import adapt
59 56 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
60 57
61 58 #-----------------------------------------------------------------------------
62 59 # utility functions
63 60 #-----------------------------------------------------------------------------
64 61
65 62 def squash_unicode(obj):
66 63 """coerce unicode back to bytestrings."""
67 64 if isinstance(obj,dict):
68 65 for key in obj.keys():
69 66 obj[key] = squash_unicode(obj[key])
70 67 if isinstance(key, unicode_type):
71 68 obj[squash_unicode(key)] = obj.pop(key)
72 69 elif isinstance(obj, list):
73 70 for i,v in enumerate(obj):
74 71 obj[i] = squash_unicode(v)
75 72 elif isinstance(obj, unicode_type):
76 73 obj = obj.encode('utf8')
77 74 return obj
78 75
79 76 #-----------------------------------------------------------------------------
80 77 # globals and defaults
81 78 #-----------------------------------------------------------------------------
82 79
83 80 # ISO8601-ify datetime objects
84 81 # allow unicode
85 82 # disallow nan, because it's not actually valid JSON
86 83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
87 84 ensure_ascii=False, allow_nan=False,
88 85 )
89 86 json_unpacker = lambda s: jsonapi.loads(s)
90 87
91 88 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
92 89 pickle_unpacker = pickle.loads
93 90
94 91 default_packer = json_packer
95 92 default_unpacker = json_unpacker
96 93
97 94 DELIM = b"<IDS|MSG>"
98 95 # singleton dummy tracker, which will always report as done
99 96 DONE = zmq.MessageTracker()
100 97
101 98 #-----------------------------------------------------------------------------
102 99 # Mixin tools for apps that use Sessions
103 100 #-----------------------------------------------------------------------------
104 101
105 102 session_aliases = dict(
106 103 ident = 'Session.session',
107 104 user = 'Session.username',
108 105 keyfile = 'Session.keyfile',
109 106 )
110 107
111 108 session_flags = {
112 109 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
113 110 'keyfile' : '' }},
114 111 """Use HMAC digests for authentication of messages.
115 112 Setting this flag will generate a new UUID to use as the HMAC key.
116 113 """),
117 114 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
118 115 """Don't authenticate messages."""),
119 116 }
120 117
121 118 def default_secure(cfg):
122 119 """Set the default behavior for a config environment to be secure.
123 120
124 121 If Session.key/keyfile have not been set, set Session.key to
125 122 a new random UUID.
126 123 """
127 124
128 125 if 'Session' in cfg:
129 126 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
130 127 return
131 128 # key/keyfile not specified, generate new UUID:
132 129 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
133 130
134 131
135 132 #-----------------------------------------------------------------------------
136 133 # Classes
137 134 #-----------------------------------------------------------------------------
138 135
139 136 class SessionFactory(LoggingConfigurable):
140 137 """The Base class for configurables that have a Session, Context, logger,
141 138 and IOLoop.
142 139 """
143 140
144 141 logname = Unicode('')
145 142 def _logname_changed(self, name, old, new):
146 143 self.log = logging.getLogger(new)
147 144
148 145 # not configurable:
149 146 context = Instance('zmq.Context')
150 147 def _context_default(self):
151 148 return zmq.Context.instance()
152 149
153 150 session = Instance('IPython.kernel.zmq.session.Session')
154 151
155 152 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
156 153 def _loop_default(self):
157 154 return IOLoop.instance()
158 155
159 156 def __init__(self, **kwargs):
160 157 super(SessionFactory, self).__init__(**kwargs)
161 158
162 159 if self.session is None:
163 160 # construct the session
164 161 self.session = Session(**kwargs)
165 162
166 163
167 164 class Message(object):
168 165 """A simple message object that maps dict keys to attributes.
169 166
170 167 A Message can be created from a dict and a dict from a Message instance
171 168 simply by calling dict(msg_obj)."""
172 169
173 170 def __init__(self, msg_dict):
174 171 dct = self.__dict__
175 172 for k, v in iteritems(dict(msg_dict)):
176 173 if isinstance(v, dict):
177 174 v = Message(v)
178 175 dct[k] = v
179 176
180 177 # Having this iterator lets dict(msg_obj) work out of the box.
181 178 def __iter__(self):
182 179 return iter(iteritems(self.__dict__))
183 180
184 181 def __repr__(self):
185 182 return repr(self.__dict__)
186 183
187 184 def __str__(self):
188 185 return pprint.pformat(self.__dict__)
189 186
190 187 def __contains__(self, k):
191 188 return k in self.__dict__
192 189
193 190 def __getitem__(self, k):
194 191 return self.__dict__[k]
195 192
196 193
197 194 def msg_header(msg_id, msg_type, username, session):
198 195 date = datetime.now()
199 196 version = kernel_protocol_version
200 197 return locals()
201 198
202 199 def extract_header(msg_or_header):
203 200 """Given a message or header, return the header."""
204 201 if not msg_or_header:
205 202 return {}
206 203 try:
207 204 # See if msg_or_header is the entire message.
208 205 h = msg_or_header['header']
209 206 except KeyError:
210 207 try:
211 208 # See if msg_or_header is just the header
212 209 h = msg_or_header['msg_id']
213 210 except KeyError:
214 211 raise
215 212 else:
216 213 h = msg_or_header
217 214 if not isinstance(h, dict):
218 215 h = dict(h)
219 216 return h
220 217
221 218 class Session(Configurable):
222 219 """Object for handling serialization and sending of messages.
223 220
224 221 The Session object handles building messages and sending them
225 222 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
226 223 other over the network via Session objects, and only need to work with the
227 224 dict-based IPython message spec. The Session will handle
228 225 serialization/deserialization, security, and metadata.
229 226
230 227 Sessions support configurable serialization via packer/unpacker traits,
231 228 and signing with HMAC digests via the key/keyfile traits.
232 229
233 230 Parameters
234 231 ----------
235 232
236 233 debug : bool
237 234 whether to trigger extra debugging statements
238 235 packer/unpacker : str : 'json', 'pickle' or import_string
239 236 importstrings for methods to serialize message parts. If just
240 237 'json' or 'pickle', predefined JSON and pickle packers will be used.
241 238 Otherwise, the entire importstring must be used.
242 239
243 240 The functions must accept at least valid JSON input, and output *bytes*.
244 241
245 242 For example, to use msgpack:
246 243 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
247 244 pack/unpack : callables
248 245 You can also set the pack/unpack callables for serialization directly.
249 246 session : bytes
250 247 the ID of this Session object. The default is to generate a new UUID.
251 248 username : unicode
252 249 username added to message headers. The default is to ask the OS.
253 250 key : bytes
254 251 The key used to initialize an HMAC signature. If unset, messages
255 252 will not be signed or checked.
256 253 keyfile : filepath
257 254 The file containing a key. If this is set, `key` will be initialized
258 255 to the contents of the file.
259 256
260 257 """
261 258
262 259 debug=Bool(False, config=True, help="""Debug output in the Session""")
263 260
264 261 packer = DottedObjectName('json',config=True,
265 262 help="""The name of the packer for serializing messages.
266 263 Should be one of 'json', 'pickle', or an import name
267 264 for a custom callable serializer.""")
268 265 def _packer_changed(self, name, old, new):
269 266 if new.lower() == 'json':
270 267 self.pack = json_packer
271 268 self.unpack = json_unpacker
272 269 self.unpacker = new
273 270 elif new.lower() == 'pickle':
274 271 self.pack = pickle_packer
275 272 self.unpack = pickle_unpacker
276 273 self.unpacker = new
277 274 else:
278 275 self.pack = import_item(str(new))
279 276
280 277 unpacker = DottedObjectName('json', config=True,
281 278 help="""The name of the unpacker for unserializing messages.
282 279 Only used with custom functions for `packer`.""")
283 280 def _unpacker_changed(self, name, old, new):
284 281 if new.lower() == 'json':
285 282 self.pack = json_packer
286 283 self.unpack = json_unpacker
287 284 self.packer = new
288 285 elif new.lower() == 'pickle':
289 286 self.pack = pickle_packer
290 287 self.unpack = pickle_unpacker
291 288 self.packer = new
292 289 else:
293 290 self.unpack = import_item(str(new))
294 291
295 292 session = CUnicode(u'', config=True,
296 293 help="""The UUID identifying this session.""")
297 294 def _session_default(self):
298 295 u = unicode_type(uuid.uuid4())
299 296 self.bsession = u.encode('ascii')
300 297 return u
301 298
302 299 def _session_changed(self, name, old, new):
303 300 self.bsession = self.session.encode('ascii')
304 301
305 302 # bsession is the session as bytes
306 303 bsession = CBytes(b'')
307 304
308 305 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
309 306 help="""Username for the Session. Default is your system username.""",
310 307 config=True)
311 308
312 309 metadata = Dict({}, config=True,
313 310 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
314 311
315 312 # if 0, no adapting to do.
316 313 adapt_version = Integer(0)
317 314
318 315 # message signature related traits:
319 316
320 317 key = CBytes(b'', config=True,
321 318 help="""execution key, for extra authentication.""")
322 319 def _key_changed(self):
323 320 self._new_auth()
324 321
325 322 signature_scheme = Unicode('hmac-sha256', config=True,
326 323 help="""The digest scheme used to construct the message signatures.
327 324 Must have the form 'hmac-HASH'.""")
328 325 def _signature_scheme_changed(self, name, old, new):
329 326 if not new.startswith('hmac-'):
330 327 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
331 328 hash_name = new.split('-', 1)[1]
332 329 try:
333 330 self.digest_mod = getattr(hashlib, hash_name)
334 331 except AttributeError:
335 332 raise TraitError("hashlib has no such attribute: %s" % hash_name)
336 333 self._new_auth()
337 334
338 335 digest_mod = Any()
339 336 def _digest_mod_default(self):
340 337 return hashlib.sha256
341 338
342 339 auth = Instance(hmac.HMAC)
343 340
344 341 def _new_auth(self):
345 342 if self.key:
346 343 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
347 344 else:
348 345 self.auth = None
349 346
350 347 digest_history = Set()
351 348 digest_history_size = Integer(2**16, config=True,
352 349 help="""The maximum number of digests to remember.
353 350
354 351 The digest history will be culled when it exceeds this value.
355 352 """
356 353 )
357 354
358 355 keyfile = Unicode('', config=True,
359 356 help="""path to file containing execution key.""")
360 357 def _keyfile_changed(self, name, old, new):
361 358 with open(new, 'rb') as f:
362 359 self.key = f.read().strip()
363 360
364 361 # for protecting against sends from forks
365 362 pid = Integer()
366 363
367 364 # serialization traits:
368 365
369 366 pack = Any(default_packer) # the actual packer function
370 367 def _pack_changed(self, name, old, new):
371 368 if not callable(new):
372 369 raise TypeError("packer must be callable, not %s"%type(new))
373 370
374 371 unpack = Any(default_unpacker) # the actual packer function
375 372 def _unpack_changed(self, name, old, new):
376 373 # unpacker is not checked - it is assumed to be
377 374 if not callable(new):
378 375 raise TypeError("unpacker must be callable, not %s"%type(new))
379 376
380 377 # thresholds:
381 378 copy_threshold = Integer(2**16, config=True,
382 379 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
383 380 buffer_threshold = Integer(MAX_BYTES, config=True,
384 381 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
385 382 item_threshold = Integer(MAX_ITEMS, config=True,
386 383 help="""The maximum number of items for a container to be introspected for custom serialization.
387 384 Containers larger than this are pickled outright.
388 385 """
389 386 )
390 387
391 388
392 389 def __init__(self, **kwargs):
393 390 """create a Session object
394 391
395 392 Parameters
396 393 ----------
397 394
398 395 debug : bool
399 396 whether to trigger extra debugging statements
400 397 packer/unpacker : str : 'json', 'pickle' or import_string
401 398 importstrings for methods to serialize message parts. If just
402 399 'json' or 'pickle', predefined JSON and pickle packers will be used.
403 400 Otherwise, the entire importstring must be used.
404 401
405 402 The functions must accept at least valid JSON input, and output
406 403 *bytes*.
407 404
408 405 For example, to use msgpack:
409 406 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
410 407 pack/unpack : callables
411 408 You can also set the pack/unpack callables for serialization
412 409 directly.
413 410 session : unicode (must be ascii)
414 411 the ID of this Session object. The default is to generate a new
415 412 UUID.
416 413 bsession : bytes
417 414 The session as bytes
418 415 username : unicode
419 416 username added to message headers. The default is to ask the OS.
420 417 key : bytes
421 418 The key used to initialize an HMAC signature. If unset, messages
422 419 will not be signed or checked.
423 420 signature_scheme : str
424 421 The message digest scheme. Currently must be of the form 'hmac-HASH',
425 422 where 'HASH' is a hashing function available in Python's hashlib.
426 423 The default is 'hmac-sha256'.
427 424 This is ignored if 'key' is empty.
428 425 keyfile : filepath
429 426 The file containing a key. If this is set, `key` will be
430 427 initialized to the contents of the file.
431 428 """
432 429 super(Session, self).__init__(**kwargs)
433 430 self._check_packers()
434 431 self.none = self.pack({})
435 432 # ensure self._session_default() if necessary, so bsession is defined:
436 433 self.session
437 434 self.pid = os.getpid()
438 435
439 436 @property
440 437 def msg_id(self):
441 438 """always return new uuid"""
442 439 return str(uuid.uuid4())
443 440
444 441 def _check_packers(self):
445 442 """check packers for datetime support."""
446 443 pack = self.pack
447 444 unpack = self.unpack
448 445
449 446 # check simple serialization
450 447 msg = dict(a=[1,'hi'])
451 448 try:
452 449 packed = pack(msg)
453 450 except Exception as e:
454 451 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
455 452 if self.packer == 'json':
456 453 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
457 454 else:
458 455 jsonmsg = ""
459 456 raise ValueError(
460 457 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
461 458 )
462 459
463 460 # ensure packed message is bytes
464 461 if not isinstance(packed, bytes):
465 462 raise ValueError("message packed to %r, but bytes are required"%type(packed))
466 463
467 464 # check that unpack is pack's inverse
468 465 try:
469 466 unpacked = unpack(packed)
470 467 assert unpacked == msg
471 468 except Exception as e:
472 469 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
473 470 if self.packer == 'json':
474 471 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
475 472 else:
476 473 jsonmsg = ""
477 474 raise ValueError(
478 475 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
479 476 )
480 477
481 478 # check datetime support
482 479 msg = dict(t=datetime.now())
483 480 try:
484 481 unpacked = unpack(pack(msg))
485 482 if isinstance(unpacked['t'], datetime):
486 483 raise ValueError("Shouldn't deserialize to datetime")
487 484 except Exception:
488 485 self.pack = lambda o: pack(squash_dates(o))
489 486 self.unpack = lambda s: unpack(s)
490 487
491 488 def msg_header(self, msg_type):
492 489 return msg_header(self.msg_id, msg_type, self.username, self.session)
493 490
494 491 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
495 492 """Return the nested message dict.
496 493
497 494 This format is different from what is sent over the wire. The
498 495 serialize/unserialize methods converts this nested message dict to the wire
499 496 format, which is a list of message parts.
500 497 """
501 498 msg = {}
502 499 header = self.msg_header(msg_type) if header is None else header
503 500 msg['header'] = header
504 501 msg['msg_id'] = header['msg_id']
505 502 msg['msg_type'] = header['msg_type']
506 503 msg['parent_header'] = {} if parent is None else extract_header(parent)
507 504 msg['content'] = {} if content is None else content
508 505 msg['metadata'] = self.metadata.copy()
509 506 if metadata is not None:
510 507 msg['metadata'].update(metadata)
511 508 return msg
512 509
513 510 def sign(self, msg_list):
514 511 """Sign a message with HMAC digest. If no auth, return b''.
515 512
516 513 Parameters
517 514 ----------
518 515 msg_list : list
519 516 The [p_header,p_parent,p_content] part of the message list.
520 517 """
521 518 if self.auth is None:
522 519 return b''
523 520 h = self.auth.copy()
524 521 for m in msg_list:
525 522 h.update(m)
526 523 return str_to_bytes(h.hexdigest())
527 524
528 525 def serialize(self, msg, ident=None):
529 526 """Serialize the message components to bytes.
530 527
531 528 This is roughly the inverse of unserialize. The serialize/unserialize
532 529 methods work with full message lists, whereas pack/unpack work with
533 530 the individual message parts in the message list.
534 531
535 532 Parameters
536 533 ----------
537 534 msg : dict or Message
538 535 The next message dict as returned by the self.msg method.
539 536
540 537 Returns
541 538 -------
542 539 msg_list : list
543 540 The list of bytes objects to be sent with the format::
544 541
545 542 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
546 543 p_metadata, p_content, buffer1, buffer2, ...]
547 544
548 545 In this list, the ``p_*`` entities are the packed or serialized
549 546 versions, so if JSON is used, these are utf8 encoded JSON strings.
550 547 """
551 548 content = msg.get('content', {})
552 549 if content is None:
553 550 content = self.none
554 551 elif isinstance(content, dict):
555 552 content = self.pack(content)
556 553 elif isinstance(content, bytes):
557 554 # content is already packed, as in a relayed message
558 555 pass
559 556 elif isinstance(content, unicode_type):
560 557 # should be bytes, but JSON often spits out unicode
561 558 content = content.encode('utf8')
562 559 else:
563 560 raise TypeError("Content incorrect type: %s"%type(content))
564 561
565 562 real_message = [self.pack(msg['header']),
566 563 self.pack(msg['parent_header']),
567 564 self.pack(msg['metadata']),
568 565 content,
569 566 ]
570 567
571 568 to_send = []
572 569
573 570 if isinstance(ident, list):
574 571 # accept list of idents
575 572 to_send.extend(ident)
576 573 elif ident is not None:
577 574 to_send.append(ident)
578 575 to_send.append(DELIM)
579 576
580 577 signature = self.sign(real_message)
581 578 to_send.append(signature)
582 579
583 580 to_send.extend(real_message)
584 581
585 582 return to_send
586 583
587 584 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
588 585 buffers=None, track=False, header=None, metadata=None):
589 586 """Build and send a message via stream or socket.
590 587
591 588 The message format used by this function internally is as follows:
592 589
593 590 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
594 591 buffer1,buffer2,...]
595 592
596 593 The serialize/unserialize methods convert the nested message dict into this
597 594 format.
598 595
599 596 Parameters
600 597 ----------
601 598
602 599 stream : zmq.Socket or ZMQStream
603 600 The socket-like object used to send the data.
604 601 msg_or_type : str or Message/dict
605 602 Normally, msg_or_type will be a msg_type unless a message is being
606 603 sent more than once. If a header is supplied, this can be set to
607 604 None and the msg_type will be pulled from the header.
608 605
609 606 content : dict or None
610 607 The content of the message (ignored if msg_or_type is a message).
611 608 header : dict or None
612 609 The header dict for the message (ignored if msg_to_type is a message).
613 610 parent : Message or dict or None
614 611 The parent or parent header describing the parent of this message
615 612 (ignored if msg_or_type is a message).
616 613 ident : bytes or list of bytes
617 614 The zmq.IDENTITY routing path.
618 615 metadata : dict or None
619 616 The metadata describing the message
620 617 buffers : list or None
621 618 The already-serialized buffers to be appended to the message.
622 619 track : bool
623 620 Whether to track. Only for use with Sockets, because ZMQStream
624 621 objects cannot track messages.
625 622
626 623
627 624 Returns
628 625 -------
629 626 msg : dict
630 627 The constructed message.
631 628 """
632 629 if not isinstance(stream, zmq.Socket):
633 630 # ZMQStreams and dummy sockets do not support tracking.
634 631 track = False
635 632
636 633 if isinstance(msg_or_type, (Message, dict)):
637 634 # We got a Message or message dict, not a msg_type so don't
638 635 # build a new Message.
639 636 msg = msg_or_type
640 637 else:
641 638 msg = self.msg(msg_or_type, content=content, parent=parent,
642 639 header=header, metadata=metadata)
643 640 if not os.getpid() == self.pid:
644 641 io.rprint("WARNING: attempted to send message from fork")
645 642 io.rprint(msg)
646 643 return
647 644 buffers = [] if buffers is None else buffers
648 645 if self.adapt_version:
649 646 msg = adapt(msg, self.adapt_version)
650 647 to_send = self.serialize(msg, ident)
651 648 to_send.extend(buffers)
652 649 longest = max([ len(s) for s in to_send ])
653 650 copy = (longest < self.copy_threshold)
654 651
655 652 if buffers and track and not copy:
656 653 # only really track when we are doing zero-copy buffers
657 654 tracker = stream.send_multipart(to_send, copy=False, track=True)
658 655 else:
659 656 # use dummy tracker, which will be done immediately
660 657 tracker = DONE
661 658 stream.send_multipart(to_send, copy=copy)
662 659
663 660 if self.debug:
664 661 pprint.pprint(msg)
665 662 pprint.pprint(to_send)
666 663 pprint.pprint(buffers)
667 664
668 665 msg['tracker'] = tracker
669 666
670 667 return msg
671 668
672 669 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
673 670 """Send a raw message via ident path.
674 671
675 672 This method is used to send a already serialized message.
676 673
677 674 Parameters
678 675 ----------
679 676 stream : ZMQStream or Socket
680 677 The ZMQ stream or socket to use for sending the message.
681 678 msg_list : list
682 679 The serialized list of messages to send. This only includes the
683 680 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
684 681 the message.
685 682 ident : ident or list
686 683 A single ident or a list of idents to use in sending.
687 684 """
688 685 to_send = []
689 686 if isinstance(ident, bytes):
690 687 ident = [ident]
691 688 if ident is not None:
692 689 to_send.extend(ident)
693 690
694 691 to_send.append(DELIM)
695 692 to_send.append(self.sign(msg_list))
696 693 to_send.extend(msg_list)
697 694 stream.send_multipart(to_send, flags, copy=copy)
698 695
699 696 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
700 697 """Receive and unpack a message.
701 698
702 699 Parameters
703 700 ----------
704 701 socket : ZMQStream or Socket
705 702 The socket or stream to use in receiving.
706 703
707 704 Returns
708 705 -------
709 706 [idents], msg
710 707 [idents] is a list of idents and msg is a nested message dict of
711 708 same format as self.msg returns.
712 709 """
713 710 if isinstance(socket, ZMQStream):
714 711 socket = socket.socket
715 712 try:
716 713 msg_list = socket.recv_multipart(mode, copy=copy)
717 714 except zmq.ZMQError as e:
718 715 if e.errno == zmq.EAGAIN:
719 716 # We can convert EAGAIN to None as we know in this case
720 717 # recv_multipart won't return None.
721 718 return None,None
722 719 else:
723 720 raise
724 721 # split multipart message into identity list and message dict
725 722 # invalid large messages can cause very expensive string comparisons
726 723 idents, msg_list = self.feed_identities(msg_list, copy)
727 724 try:
728 725 return idents, self.unserialize(msg_list, content=content, copy=copy)
729 726 except Exception as e:
730 727 # TODO: handle it
731 728 raise e
732 729
733 730 def feed_identities(self, msg_list, copy=True):
734 731 """Split the identities from the rest of the message.
735 732
736 733 Feed until DELIM is reached, then return the prefix as idents and
737 734 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
738 735 but that would be silly.
739 736
740 737 Parameters
741 738 ----------
742 739 msg_list : a list of Message or bytes objects
743 740 The message to be split.
744 741 copy : bool
745 742 flag determining whether the arguments are bytes or Messages
746 743
747 744 Returns
748 745 -------
749 746 (idents, msg_list) : two lists
750 747 idents will always be a list of bytes, each of which is a ZMQ
751 748 identity. msg_list will be a list of bytes or zmq.Messages of the
752 749 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
753 750 should be unpackable/unserializable via self.unserialize at this
754 751 point.
755 752 """
756 753 if copy:
757 754 idx = msg_list.index(DELIM)
758 755 return msg_list[:idx], msg_list[idx+1:]
759 756 else:
760 757 failed = True
761 758 for idx,m in enumerate(msg_list):
762 759 if m.bytes == DELIM:
763 760 failed = False
764 761 break
765 762 if failed:
766 763 raise ValueError("DELIM not in msg_list")
767 764 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
768 765 return [m.bytes for m in idents], msg_list
769 766
770 767 def _add_digest(self, signature):
771 768 """add a digest to history to protect against replay attacks"""
772 769 if self.digest_history_size == 0:
773 770 # no history, never add digests
774 771 return
775 772
776 773 self.digest_history.add(signature)
777 774 if len(self.digest_history) > self.digest_history_size:
778 775 # threshold reached, cull 10%
779 776 self._cull_digest_history()
780 777
781 778 def _cull_digest_history(self):
782 779 """cull the digest history
783 780
784 781 Removes a randomly selected 10% of the digest history
785 782 """
786 783 current = len(self.digest_history)
787 784 n_to_cull = max(int(current // 10), current - self.digest_history_size)
788 785 if n_to_cull >= current:
789 786 self.digest_history = set()
790 787 return
791 788 to_cull = random.sample(self.digest_history, n_to_cull)
792 789 self.digest_history.difference_update(to_cull)
793 790
794 791 def unserialize(self, msg_list, content=True, copy=True):
795 792 """Unserialize a msg_list to a nested message dict.
796 793
797 794 This is roughly the inverse of serialize. The serialize/unserialize
798 795 methods work with full message lists, whereas pack/unpack work with
799 796 the individual message parts in the message list.
800 797
801 798 Parameters
802 799 ----------
803 800 msg_list : list of bytes or Message objects
804 801 The list of message parts of the form [HMAC,p_header,p_parent,
805 802 p_metadata,p_content,buffer1,buffer2,...].
806 803 content : bool (True)
807 804 Whether to unpack the content dict (True), or leave it packed
808 805 (False).
809 806 copy : bool (True)
810 807 Whether to return the bytes (True), or the non-copying Message
811 808 object in each place (False).
812 809
813 810 Returns
814 811 -------
815 812 msg : dict
816 813 The nested message dict with top-level keys [header, parent_header,
817 814 content, buffers].
818 815 """
819 816 minlen = 5
820 817 message = {}
821 818 if not copy:
822 819 for i in range(minlen):
823 820 msg_list[i] = msg_list[i].bytes
824 821 if self.auth is not None:
825 822 signature = msg_list[0]
826 823 if not signature:
827 824 raise ValueError("Unsigned Message")
828 825 if signature in self.digest_history:
829 826 raise ValueError("Duplicate Signature: %r" % signature)
830 827 self._add_digest(signature)
831 828 check = self.sign(msg_list[1:5])
832 829 if not compare_digest(signature, check):
833 830 raise ValueError("Invalid Signature: %r" % signature)
834 831 if not len(msg_list) >= minlen:
835 832 raise TypeError("malformed message, must have at least %i elements"%minlen)
836 833 header = self.unpack(msg_list[1])
837 834 message['header'] = extract_dates(header)
838 835 message['msg_id'] = header['msg_id']
839 836 message['msg_type'] = header['msg_type']
840 837 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
841 838 message['metadata'] = self.unpack(msg_list[3])
842 839 if content:
843 840 message['content'] = self.unpack(msg_list[4])
844 841 else:
845 842 message['content'] = msg_list[4]
846 843
847 844 message['buffers'] = msg_list[5:]
848 845 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
849 846 # adapt to the current version
850 847 return adapt(message)
851 848 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
852 849
853 850 def test_msg2obj():
854 851 am = dict(x=1)
855 852 ao = Message(am)
856 853 assert ao.x == am['x']
857 854
858 855 am['y'] = dict(z=1)
859 856 ao = Message(am)
860 857 assert ao.y.z == am['y']['z']
861 858
862 859 k1, k2 = 'y', 'z'
863 860 assert ao[k1][k2] == am[k1][k2]
864 861
865 862 am2 = dict(ao)
866 863 assert am['x'] == am2['x']
867 864 assert am['y']['z'] == am2['y']['z']
868 865
General Comments 0
You need to be logged in to leave comments. Login now