##// END OF EJS Templates
adapt messages in Session...
MinRK -
Show More
@@ -1,843 +1,852 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
23 23 try:
24 24 import cPickle
25 25 pickle = cPickle
26 26 except:
27 27 cPickle = None
28 28 import pickle
29 29
30 30 import zmq
31 31 from zmq.utils import jsonapi
32 32 from zmq.eventloop.ioloop import IOLoop
33 33 from zmq.eventloop.zmqstream import ZMQStream
34 34
35 from IPython.core.release import kernel_protocol_version
35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 37 from IPython.utils import io
38 38 from IPython.utils.importstring import import_item
39 39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 41 iteritems)
42 42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 43 DottedObjectName, CUnicode, Dict, Integer,
44 44 TraitError,
45 45 )
46 from IPython.kernel.adapter import adapt
46 47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
47 48
48 49 #-----------------------------------------------------------------------------
49 50 # utility functions
50 51 #-----------------------------------------------------------------------------
51 52
52 53 def squash_unicode(obj):
53 54 """coerce unicode back to bytestrings."""
54 55 if isinstance(obj,dict):
55 56 for key in obj.keys():
56 57 obj[key] = squash_unicode(obj[key])
57 58 if isinstance(key, unicode_type):
58 59 obj[squash_unicode(key)] = obj.pop(key)
59 60 elif isinstance(obj, list):
60 61 for i,v in enumerate(obj):
61 62 obj[i] = squash_unicode(v)
62 63 elif isinstance(obj, unicode_type):
63 64 obj = obj.encode('utf8')
64 65 return obj
65 66
66 67 #-----------------------------------------------------------------------------
67 68 # globals and defaults
68 69 #-----------------------------------------------------------------------------
69 70
70 71 # ISO8601-ify datetime objects
71 72 # allow unicode
72 73 # disallow nan, because it's not actually valid JSON
73 74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
74 75 ensure_ascii=False, allow_nan=False,
75 76 )
76 77 json_unpacker = lambda s: jsonapi.loads(s)
77 78
78 79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
79 80 pickle_unpacker = pickle.loads
80 81
81 82 default_packer = json_packer
82 83 default_unpacker = json_unpacker
83 84
84 85 DELIM = b"<IDS|MSG>"
85 86 # singleton dummy tracker, which will always report as done
86 87 DONE = zmq.MessageTracker()
87 88
88 89 #-----------------------------------------------------------------------------
89 90 # Mixin tools for apps that use Sessions
90 91 #-----------------------------------------------------------------------------
91 92
92 93 session_aliases = dict(
93 94 ident = 'Session.session',
94 95 user = 'Session.username',
95 96 keyfile = 'Session.keyfile',
96 97 )
97 98
98 99 session_flags = {
99 100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
100 101 'keyfile' : '' }},
101 102 """Use HMAC digests for authentication of messages.
102 103 Setting this flag will generate a new UUID to use as the HMAC key.
103 104 """),
104 105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
105 106 """Don't authenticate messages."""),
106 107 }
107 108
108 109 def default_secure(cfg):
109 110 """Set the default behavior for a config environment to be secure.
110 111
111 112 If Session.key/keyfile have not been set, set Session.key to
112 113 a new random UUID.
113 114 """
114 115
115 116 if 'Session' in cfg:
116 117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
117 118 return
118 119 # key/keyfile not specified, generate new UUID:
119 120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
120 121
121 122
122 123 #-----------------------------------------------------------------------------
123 124 # Classes
124 125 #-----------------------------------------------------------------------------
125 126
126 127 class SessionFactory(LoggingConfigurable):
127 128 """The Base class for configurables that have a Session, Context, logger,
128 129 and IOLoop.
129 130 """
130 131
131 132 logname = Unicode('')
132 133 def _logname_changed(self, name, old, new):
133 134 self.log = logging.getLogger(new)
134 135
135 136 # not configurable:
136 137 context = Instance('zmq.Context')
137 138 def _context_default(self):
138 139 return zmq.Context.instance()
139 140
140 141 session = Instance('IPython.kernel.zmq.session.Session')
141 142
142 143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
143 144 def _loop_default(self):
144 145 return IOLoop.instance()
145 146
146 147 def __init__(self, **kwargs):
147 148 super(SessionFactory, self).__init__(**kwargs)
148 149
149 150 if self.session is None:
150 151 # construct the session
151 152 self.session = Session(**kwargs)
152 153
153 154
154 155 class Message(object):
155 156 """A simple message object that maps dict keys to attributes.
156 157
157 158 A Message can be created from a dict and a dict from a Message instance
158 159 simply by calling dict(msg_obj)."""
159 160
160 161 def __init__(self, msg_dict):
161 162 dct = self.__dict__
162 163 for k, v in iteritems(dict(msg_dict)):
163 164 if isinstance(v, dict):
164 165 v = Message(v)
165 166 dct[k] = v
166 167
167 168 # Having this iterator lets dict(msg_obj) work out of the box.
168 169 def __iter__(self):
169 170 return iter(iteritems(self.__dict__))
170 171
171 172 def __repr__(self):
172 173 return repr(self.__dict__)
173 174
174 175 def __str__(self):
175 176 return pprint.pformat(self.__dict__)
176 177
177 178 def __contains__(self, k):
178 179 return k in self.__dict__
179 180
180 181 def __getitem__(self, k):
181 182 return self.__dict__[k]
182 183
183 184
184 185 def msg_header(msg_id, msg_type, username, session):
185 186 date = datetime.now()
186 187 version = kernel_protocol_version
187 188 return locals()
188 189
189 190 def extract_header(msg_or_header):
190 191 """Given a message or header, return the header."""
191 192 if not msg_or_header:
192 193 return {}
193 194 try:
194 195 # See if msg_or_header is the entire message.
195 196 h = msg_or_header['header']
196 197 except KeyError:
197 198 try:
198 199 # See if msg_or_header is just the header
199 200 h = msg_or_header['msg_id']
200 201 except KeyError:
201 202 raise
202 203 else:
203 204 h = msg_or_header
204 205 if not isinstance(h, dict):
205 206 h = dict(h)
206 207 return h
207 208
208 209 class Session(Configurable):
209 210 """Object for handling serialization and sending of messages.
210 211
211 212 The Session object handles building messages and sending them
212 213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
213 214 other over the network via Session objects, and only need to work with the
214 215 dict-based IPython message spec. The Session will handle
215 216 serialization/deserialization, security, and metadata.
216 217
217 218 Sessions support configurable serialiization via packer/unpacker traits,
218 219 and signing with HMAC digests via the key/keyfile traits.
219 220
220 221 Parameters
221 222 ----------
222 223
223 224 debug : bool
224 225 whether to trigger extra debugging statements
225 226 packer/unpacker : str : 'json', 'pickle' or import_string
226 227 importstrings for methods to serialize message parts. If just
227 228 'json' or 'pickle', predefined JSON and pickle packers will be used.
228 229 Otherwise, the entire importstring must be used.
229 230
230 231 The functions must accept at least valid JSON input, and output *bytes*.
231 232
232 233 For example, to use msgpack:
233 234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
234 235 pack/unpack : callables
235 236 You can also set the pack/unpack callables for serialization directly.
236 237 session : bytes
237 238 the ID of this Session object. The default is to generate a new UUID.
238 239 username : unicode
239 240 username added to message headers. The default is to ask the OS.
240 241 key : bytes
241 242 The key used to initialize an HMAC signature. If unset, messages
242 243 will not be signed or checked.
243 244 keyfile : filepath
244 245 The file containing a key. If this is set, `key` will be initialized
245 246 to the contents of the file.
246 247
247 248 """
248 249
249 250 debug=Bool(False, config=True, help="""Debug output in the Session""")
250 251
251 252 packer = DottedObjectName('json',config=True,
252 253 help="""The name of the packer for serializing messages.
253 254 Should be one of 'json', 'pickle', or an import name
254 255 for a custom callable serializer.""")
255 256 def _packer_changed(self, name, old, new):
256 257 if new.lower() == 'json':
257 258 self.pack = json_packer
258 259 self.unpack = json_unpacker
259 260 self.unpacker = new
260 261 elif new.lower() == 'pickle':
261 262 self.pack = pickle_packer
262 263 self.unpack = pickle_unpacker
263 264 self.unpacker = new
264 265 else:
265 266 self.pack = import_item(str(new))
266 267
267 268 unpacker = DottedObjectName('json', config=True,
268 269 help="""The name of the unpacker for unserializing messages.
269 270 Only used with custom functions for `packer`.""")
270 271 def _unpacker_changed(self, name, old, new):
271 272 if new.lower() == 'json':
272 273 self.pack = json_packer
273 274 self.unpack = json_unpacker
274 275 self.packer = new
275 276 elif new.lower() == 'pickle':
276 277 self.pack = pickle_packer
277 278 self.unpack = pickle_unpacker
278 279 self.packer = new
279 280 else:
280 281 self.unpack = import_item(str(new))
281 282
282 283 session = CUnicode(u'', config=True,
283 284 help="""The UUID identifying this session.""")
284 285 def _session_default(self):
285 286 u = unicode_type(uuid.uuid4())
286 287 self.bsession = u.encode('ascii')
287 288 return u
288 289
289 290 def _session_changed(self, name, old, new):
290 291 self.bsession = self.session.encode('ascii')
291 292
292 293 # bsession is the session as bytes
293 294 bsession = CBytes(b'')
294 295
295 296 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
296 297 help="""Username for the Session. Default is your system username.""",
297 298 config=True)
298 299
299 300 metadata = Dict({}, config=True,
300 301 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
302
303 # if 0, no adapting to do.
304 adapt_version = Integer(0)
301 305
302 306 # message signature related traits:
303 307
304 308 key = CBytes(b'', config=True,
305 309 help="""execution key, for extra authentication.""")
306 310 def _key_changed(self, name, old, new):
307 311 if new:
308 312 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
309 313 else:
310 314 self.auth = None
311 315
312 316 signature_scheme = Unicode('hmac-sha256', config=True,
313 317 help="""The digest scheme used to construct the message signatures.
314 318 Must have the form 'hmac-HASH'.""")
315 319 def _signature_scheme_changed(self, name, old, new):
316 320 if not new.startswith('hmac-'):
317 321 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
318 322 hash_name = new.split('-', 1)[1]
319 323 try:
320 324 self.digest_mod = getattr(hashlib, hash_name)
321 325 except AttributeError:
322 326 raise TraitError("hashlib has no such attribute: %s" % hash_name)
323 327
324 328 digest_mod = Any()
325 329 def _digest_mod_default(self):
326 330 return hashlib.sha256
327 331
328 332 auth = Instance(hmac.HMAC)
329 333
330 334 digest_history = Set()
331 335 digest_history_size = Integer(2**16, config=True,
332 336 help="""The maximum number of digests to remember.
333 337
334 338 The digest history will be culled when it exceeds this value.
335 339 """
336 340 )
337 341
338 342 keyfile = Unicode('', config=True,
339 343 help="""path to file containing execution key.""")
340 344 def _keyfile_changed(self, name, old, new):
341 345 with open(new, 'rb') as f:
342 346 self.key = f.read().strip()
343 347
344 348 # for protecting against sends from forks
345 349 pid = Integer()
346 350
347 351 # serialization traits:
348 352
349 353 pack = Any(default_packer) # the actual packer function
350 354 def _pack_changed(self, name, old, new):
351 355 if not callable(new):
352 356 raise TypeError("packer must be callable, not %s"%type(new))
353 357
354 358 unpack = Any(default_unpacker) # the actual packer function
355 359 def _unpack_changed(self, name, old, new):
356 360 # unpacker is not checked - it is assumed to be
357 361 if not callable(new):
358 362 raise TypeError("unpacker must be callable, not %s"%type(new))
359 363
360 364 # thresholds:
361 365 copy_threshold = Integer(2**16, config=True,
362 366 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
363 367 buffer_threshold = Integer(MAX_BYTES, config=True,
364 368 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
365 369 item_threshold = Integer(MAX_ITEMS, config=True,
366 370 help="""The maximum number of items for a container to be introspected for custom serialization.
367 371 Containers larger than this are pickled outright.
368 372 """
369 373 )
370 374
371 375
372 376 def __init__(self, **kwargs):
373 377 """create a Session object
374 378
375 379 Parameters
376 380 ----------
377 381
378 382 debug : bool
379 383 whether to trigger extra debugging statements
380 384 packer/unpacker : str : 'json', 'pickle' or import_string
381 385 importstrings for methods to serialize message parts. If just
382 386 'json' or 'pickle', predefined JSON and pickle packers will be used.
383 387 Otherwise, the entire importstring must be used.
384 388
385 389 The functions must accept at least valid JSON input, and output
386 390 *bytes*.
387 391
388 392 For example, to use msgpack:
389 393 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
390 394 pack/unpack : callables
391 395 You can also set the pack/unpack callables for serialization
392 396 directly.
393 397 session : unicode (must be ascii)
394 398 the ID of this Session object. The default is to generate a new
395 399 UUID.
396 400 bsession : bytes
397 401 The session as bytes
398 402 username : unicode
399 403 username added to message headers. The default is to ask the OS.
400 404 key : bytes
401 405 The key used to initialize an HMAC signature. If unset, messages
402 406 will not be signed or checked.
403 407 signature_scheme : str
404 408 The message digest scheme. Currently must be of the form 'hmac-HASH',
405 409 where 'HASH' is a hashing function available in Python's hashlib.
406 410 The default is 'hmac-sha256'.
407 411 This is ignored if 'key' is empty.
408 412 keyfile : filepath
409 413 The file containing a key. If this is set, `key` will be
410 414 initialized to the contents of the file.
411 415 """
412 416 super(Session, self).__init__(**kwargs)
413 417 self._check_packers()
414 418 self.none = self.pack({})
415 419 # ensure self._session_default() if necessary, so bsession is defined:
416 420 self.session
417 421 self.pid = os.getpid()
418 422
419 423 @property
420 424 def msg_id(self):
421 425 """always return new uuid"""
422 426 return str(uuid.uuid4())
423 427
424 428 def _check_packers(self):
425 429 """check packers for datetime support."""
426 430 pack = self.pack
427 431 unpack = self.unpack
428 432
429 433 # check simple serialization
430 434 msg = dict(a=[1,'hi'])
431 435 try:
432 436 packed = pack(msg)
433 437 except Exception as e:
434 438 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
435 439 if self.packer == 'json':
436 440 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
437 441 else:
438 442 jsonmsg = ""
439 443 raise ValueError(
440 444 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
441 445 )
442 446
443 447 # ensure packed message is bytes
444 448 if not isinstance(packed, bytes):
445 449 raise ValueError("message packed to %r, but bytes are required"%type(packed))
446 450
447 451 # check that unpack is pack's inverse
448 452 try:
449 453 unpacked = unpack(packed)
450 454 assert unpacked == msg
451 455 except Exception as e:
452 456 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
453 457 if self.packer == 'json':
454 458 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
455 459 else:
456 460 jsonmsg = ""
457 461 raise ValueError(
458 462 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
459 463 )
460 464
461 465 # check datetime support
462 466 msg = dict(t=datetime.now())
463 467 try:
464 468 unpacked = unpack(pack(msg))
465 469 if isinstance(unpacked['t'], datetime):
466 470 raise ValueError("Shouldn't deserialize to datetime")
467 471 except Exception:
468 472 self.pack = lambda o: pack(squash_dates(o))
469 473 self.unpack = lambda s: unpack(s)
470 474
471 475 def msg_header(self, msg_type):
472 476 return msg_header(self.msg_id, msg_type, self.username, self.session)
473 477
474 478 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
475 479 """Return the nested message dict.
476 480
477 481 This format is different from what is sent over the wire. The
478 482 serialize/unserialize methods converts this nested message dict to the wire
479 483 format, which is a list of message parts.
480 484 """
481 485 msg = {}
482 486 header = self.msg_header(msg_type) if header is None else header
483 487 msg['header'] = header
484 488 msg['msg_id'] = header['msg_id']
485 489 msg['msg_type'] = header['msg_type']
486 490 msg['parent_header'] = {} if parent is None else extract_header(parent)
487 491 msg['content'] = {} if content is None else content
488 492 msg['metadata'] = self.metadata.copy()
489 493 if metadata is not None:
490 494 msg['metadata'].update(metadata)
491 495 return msg
492 496
493 497 def sign(self, msg_list):
494 498 """Sign a message with HMAC digest. If no auth, return b''.
495 499
496 500 Parameters
497 501 ----------
498 502 msg_list : list
499 503 The [p_header,p_parent,p_content] part of the message list.
500 504 """
501 505 if self.auth is None:
502 506 return b''
503 507 h = self.auth.copy()
504 508 for m in msg_list:
505 509 h.update(m)
506 510 return str_to_bytes(h.hexdigest())
507 511
508 512 def serialize(self, msg, ident=None):
509 513 """Serialize the message components to bytes.
510 514
511 515 This is roughly the inverse of unserialize. The serialize/unserialize
512 516 methods work with full message lists, whereas pack/unpack work with
513 517 the individual message parts in the message list.
514 518
515 519 Parameters
516 520 ----------
517 521 msg : dict or Message
518 522 The nexted message dict as returned by the self.msg method.
519 523
520 524 Returns
521 525 -------
522 526 msg_list : list
523 527 The list of bytes objects to be sent with the format::
524 528
525 529 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
526 530 p_metadata, p_content, buffer1, buffer2, ...]
527 531
528 532 In this list, the ``p_*`` entities are the packed or serialized
529 533 versions, so if JSON is used, these are utf8 encoded JSON strings.
530 534 """
531 535 content = msg.get('content', {})
532 536 if content is None:
533 537 content = self.none
534 538 elif isinstance(content, dict):
535 539 content = self.pack(content)
536 540 elif isinstance(content, bytes):
537 541 # content is already packed, as in a relayed message
538 542 pass
539 543 elif isinstance(content, unicode_type):
540 544 # should be bytes, but JSON often spits out unicode
541 545 content = content.encode('utf8')
542 546 else:
543 547 raise TypeError("Content incorrect type: %s"%type(content))
544 548
545 549 real_message = [self.pack(msg['header']),
546 550 self.pack(msg['parent_header']),
547 551 self.pack(msg['metadata']),
548 552 content,
549 553 ]
550 554
551 555 to_send = []
552 556
553 557 if isinstance(ident, list):
554 558 # accept list of idents
555 559 to_send.extend(ident)
556 560 elif ident is not None:
557 561 to_send.append(ident)
558 562 to_send.append(DELIM)
559 563
560 564 signature = self.sign(real_message)
561 565 to_send.append(signature)
562 566
563 567 to_send.extend(real_message)
564 568
565 569 return to_send
566 570
567 571 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
568 572 buffers=None, track=False, header=None, metadata=None):
569 573 """Build and send a message via stream or socket.
570 574
571 575 The message format used by this function internally is as follows:
572 576
573 577 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
574 578 buffer1,buffer2,...]
575 579
576 580 The serialize/unserialize methods convert the nested message dict into this
577 581 format.
578 582
579 583 Parameters
580 584 ----------
581 585
582 586 stream : zmq.Socket or ZMQStream
583 587 The socket-like object used to send the data.
584 588 msg_or_type : str or Message/dict
585 589 Normally, msg_or_type will be a msg_type unless a message is being
586 590 sent more than once. If a header is supplied, this can be set to
587 591 None and the msg_type will be pulled from the header.
588 592
589 593 content : dict or None
590 594 The content of the message (ignored if msg_or_type is a message).
591 595 header : dict or None
592 596 The header dict for the message (ignored if msg_to_type is a message).
593 597 parent : Message or dict or None
594 598 The parent or parent header describing the parent of this message
595 599 (ignored if msg_or_type is a message).
596 600 ident : bytes or list of bytes
597 601 The zmq.IDENTITY routing path.
598 602 metadata : dict or None
599 603 The metadata describing the message
600 604 buffers : list or None
601 605 The already-serialized buffers to be appended to the message.
602 606 track : bool
603 607 Whether to track. Only for use with Sockets, because ZMQStream
604 608 objects cannot track messages.
605 609
606 610
607 611 Returns
608 612 -------
609 613 msg : dict
610 614 The constructed message.
611 615 """
612 616 if not isinstance(stream, zmq.Socket):
613 617 # ZMQStreams and dummy sockets do not support tracking.
614 618 track = False
615 619
616 620 if isinstance(msg_or_type, (Message, dict)):
617 621 # We got a Message or message dict, not a msg_type so don't
618 622 # build a new Message.
619 623 msg = msg_or_type
620 624 else:
621 625 msg = self.msg(msg_or_type, content=content, parent=parent,
622 626 header=header, metadata=metadata)
623 627 if not os.getpid() == self.pid:
624 628 io.rprint("WARNING: attempted to send message from fork")
625 629 io.rprint(msg)
626 630 return
627 631 buffers = [] if buffers is None else buffers
632 if self.adapt_version:
633 msg = adapt(msg, self.adapt_version)
628 634 to_send = self.serialize(msg, ident)
629 635 to_send.extend(buffers)
630 636 longest = max([ len(s) for s in to_send ])
631 637 copy = (longest < self.copy_threshold)
632 638
633 639 if buffers and track and not copy:
634 640 # only really track when we are doing zero-copy buffers
635 641 tracker = stream.send_multipart(to_send, copy=False, track=True)
636 642 else:
637 643 # use dummy tracker, which will be done immediately
638 644 tracker = DONE
639 645 stream.send_multipart(to_send, copy=copy)
640 646
641 647 if self.debug:
642 648 pprint.pprint(msg)
643 649 pprint.pprint(to_send)
644 650 pprint.pprint(buffers)
645 651
646 652 msg['tracker'] = tracker
647 653
648 654 return msg
649 655
650 656 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
651 657 """Send a raw message via ident path.
652 658
653 659 This method is used to send a already serialized message.
654 660
655 661 Parameters
656 662 ----------
657 663 stream : ZMQStream or Socket
658 664 The ZMQ stream or socket to use for sending the message.
659 665 msg_list : list
660 666 The serialized list of messages to send. This only includes the
661 667 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
662 668 the message.
663 669 ident : ident or list
664 670 A single ident or a list of idents to use in sending.
665 671 """
666 672 to_send = []
667 673 if isinstance(ident, bytes):
668 674 ident = [ident]
669 675 if ident is not None:
670 676 to_send.extend(ident)
671 677
672 678 to_send.append(DELIM)
673 679 to_send.append(self.sign(msg_list))
674 680 to_send.extend(msg_list)
675 681 stream.send_multipart(to_send, flags, copy=copy)
676 682
677 683 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
678 684 """Receive and unpack a message.
679 685
680 686 Parameters
681 687 ----------
682 688 socket : ZMQStream or Socket
683 689 The socket or stream to use in receiving.
684 690
685 691 Returns
686 692 -------
687 693 [idents], msg
688 694 [idents] is a list of idents and msg is a nested message dict of
689 695 same format as self.msg returns.
690 696 """
691 697 if isinstance(socket, ZMQStream):
692 698 socket = socket.socket
693 699 try:
694 700 msg_list = socket.recv_multipart(mode, copy=copy)
695 701 except zmq.ZMQError as e:
696 702 if e.errno == zmq.EAGAIN:
697 703 # We can convert EAGAIN to None as we know in this case
698 704 # recv_multipart won't return None.
699 705 return None,None
700 706 else:
701 707 raise
702 708 # split multipart message into identity list and message dict
703 709 # invalid large messages can cause very expensive string comparisons
704 710 idents, msg_list = self.feed_identities(msg_list, copy)
705 711 try:
706 712 return idents, self.unserialize(msg_list, content=content, copy=copy)
707 713 except Exception as e:
708 714 # TODO: handle it
709 715 raise e
710 716
711 717 def feed_identities(self, msg_list, copy=True):
712 718 """Split the identities from the rest of the message.
713 719
714 720 Feed until DELIM is reached, then return the prefix as idents and
715 721 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
716 722 but that would be silly.
717 723
718 724 Parameters
719 725 ----------
720 726 msg_list : a list of Message or bytes objects
721 727 The message to be split.
722 728 copy : bool
723 729 flag determining whether the arguments are bytes or Messages
724 730
725 731 Returns
726 732 -------
727 733 (idents, msg_list) : two lists
728 734 idents will always be a list of bytes, each of which is a ZMQ
729 735 identity. msg_list will be a list of bytes or zmq.Messages of the
730 736 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
731 737 should be unpackable/unserializable via self.unserialize at this
732 738 point.
733 739 """
734 740 if copy:
735 741 idx = msg_list.index(DELIM)
736 742 return msg_list[:idx], msg_list[idx+1:]
737 743 else:
738 744 failed = True
739 745 for idx,m in enumerate(msg_list):
740 746 if m.bytes == DELIM:
741 747 failed = False
742 748 break
743 749 if failed:
744 750 raise ValueError("DELIM not in msg_list")
745 751 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
746 752 return [m.bytes for m in idents], msg_list
747 753
748 754 def _add_digest(self, signature):
749 755 """add a digest to history to protect against replay attacks"""
750 756 if self.digest_history_size == 0:
751 757 # no history, never add digests
752 758 return
753 759
754 760 self.digest_history.add(signature)
755 761 if len(self.digest_history) > self.digest_history_size:
756 762 # threshold reached, cull 10%
757 763 self._cull_digest_history()
758 764
759 765 def _cull_digest_history(self):
760 766 """cull the digest history
761 767
762 768 Removes a randomly selected 10% of the digest history
763 769 """
764 770 current = len(self.digest_history)
765 771 n_to_cull = max(int(current // 10), current - self.digest_history_size)
766 772 if n_to_cull >= current:
767 773 self.digest_history = set()
768 774 return
769 775 to_cull = random.sample(self.digest_history, n_to_cull)
770 776 self.digest_history.difference_update(to_cull)
771 777
772 778 def unserialize(self, msg_list, content=True, copy=True):
773 779 """Unserialize a msg_list to a nested message dict.
774 780
775 781 This is roughly the inverse of serialize. The serialize/unserialize
776 782 methods work with full message lists, whereas pack/unpack work with
777 783 the individual message parts in the message list.
778 784
779 785 Parameters
780 786 ----------
781 787 msg_list : list of bytes or Message objects
782 788 The list of message parts of the form [HMAC,p_header,p_parent,
783 789 p_metadata,p_content,buffer1,buffer2,...].
784 790 content : bool (True)
785 791 Whether to unpack the content dict (True), or leave it packed
786 792 (False).
787 793 copy : bool (True)
788 794 Whether to return the bytes (True), or the non-copying Message
789 795 object in each place (False).
790 796
791 797 Returns
792 798 -------
793 799 msg : dict
794 800 The nested message dict with top-level keys [header, parent_header,
795 801 content, buffers].
796 802 """
797 803 minlen = 5
798 804 message = {}
799 805 if not copy:
800 806 for i in range(minlen):
801 807 msg_list[i] = msg_list[i].bytes
802 808 if self.auth is not None:
803 809 signature = msg_list[0]
804 810 if not signature:
805 811 raise ValueError("Unsigned Message")
806 812 if signature in self.digest_history:
807 813 raise ValueError("Duplicate Signature: %r" % signature)
808 814 self._add_digest(signature)
809 815 check = self.sign(msg_list[1:5])
810 816 if not signature == check:
811 817 raise ValueError("Invalid Signature: %r" % signature)
812 818 if not len(msg_list) >= minlen:
813 819 raise TypeError("malformed message, must have at least %i elements"%minlen)
814 820 header = self.unpack(msg_list[1])
815 821 message['header'] = extract_dates(header)
816 822 message['msg_id'] = header['msg_id']
817 823 message['msg_type'] = header['msg_type']
818 824 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
819 825 message['metadata'] = self.unpack(msg_list[3])
820 826 if content:
821 827 message['content'] = self.unpack(msg_list[4])
822 828 else:
823 829 message['content'] = msg_list[4]
824 830
825 831 message['buffers'] = msg_list[5:]
826 return message
832 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
833 # adapt to the current version
834 return adapt(message)
835 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
827 836
828 837 def test_msg2obj():
829 838 am = dict(x=1)
830 839 ao = Message(am)
831 840 assert ao.x == am['x']
832 841
833 842 am['y'] = dict(z=1)
834 843 ao = Message(am)
835 844 assert ao.y.z == am['y']['z']
836 845
837 846 k1, k2 = 'y', 'z'
838 847 assert ao[k1][k2] == am[k1][k2]
839 848
840 849 am2 = dict(ao)
841 850 assert am['x'] == am2['x']
842 851 assert am['y']['z'] == am2['y']['z']
843 852
General Comments 0
You need to be logged in to leave comments. Login now