##// END OF EJS Templates
don't ensure_ascii in JSON messages in Session
MinRK -
Show More
@@ -1,850 +1,837 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9
10 Authors:
11
12 * Min RK
13 * Brian Granger
14 * Fernando Perez
15 9 """
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
22 10
23 #-----------------------------------------------------------------------------
24 # Imports
25 #-----------------------------------------------------------------------------
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
26 13
27 14 import hashlib
28 15 import hmac
29 16 import logging
30 17 import os
31 18 import pprint
32 19 import random
33 20 import uuid
34 21 from datetime import datetime
35 22
36 23 try:
37 24 import cPickle
38 25 pickle = cPickle
39 26 except:
40 27 cPickle = None
41 28 import pickle
42 29
43 30 import zmq
44 31 from zmq.utils import jsonapi
45 32 from zmq.eventloop.ioloop import IOLoop
46 33 from zmq.eventloop.zmqstream import ZMQStream
47 34
48 35 from IPython.config.configurable import Configurable, LoggingConfigurable
49 36 from IPython.utils import io
50 37 from IPython.utils.importstring import import_item
51 38 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 39 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 40 iteritems)
54 41 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 42 DottedObjectName, CUnicode, Dict, Integer,
56 43 TraitError,
57 44 )
58 45 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59 46
60 47 #-----------------------------------------------------------------------------
61 48 # utility functions
62 49 #-----------------------------------------------------------------------------
63 50
def squash_unicode(obj):
    """Coerce unicode back to bytestrings, recursing into dicts and lists.

    Dicts and lists are modified in place and returned; a unicode string is
    returned as a new utf8-encoded bytes object. Any other object is returned
    unchanged.
    """
    if isinstance(obj, dict):
        # Iterate over a snapshot of the keys: entries are re-keyed inside
        # the loop, and mutating a dict while iterating its live key view
        # raises RuntimeError on Python 3.
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
77 64
78 65 #-----------------------------------------------------------------------------
79 66 # globals and defaults
80 67 #-----------------------------------------------------------------------------
81 68
def json_packer(obj):
    """Pack ``obj`` with ``jsonapi.dumps``.

    ``date_default`` is the fallback for objects JSON cannot encode natively
    (it ISO8601-ifies datetime objects), and ``ensure_ascii=False`` keeps
    non-ASCII text as-is rather than escaping it as ``\\uXXXX`` sequences.
    """
    return jsonapi.dumps(obj, default=date_default, ensure_ascii=False)


def json_unpacker(s):
    """Unpack a JSON message part with ``jsonapi.loads``."""
    return jsonapi.loads(s)


def pickle_packer(o):
    """Pickle ``o`` with the highest protocol, after squashing its dates."""
    return pickle.dumps(squash_dates(o), -1)


pickle_unpacker = pickle.loads

# JSON is the default wire serialization
default_packer = json_packer
default_unpacker = json_unpacker

# frame separating the routing identities from the message proper
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
95 82
96 83 #-----------------------------------------------------------------------------
97 84 # Mixin tools for apps that use Sessions
98 85 #-----------------------------------------------------------------------------
99 86
# command-line aliases mapping short names onto Session traits
session_aliases = {
    'ident': 'Session.session',
    'user': 'Session.username',
    'keyfile': 'Session.keyfile',
}

# command-line flags: name -> (config to apply, help text).
# The UUID used for 'secure' is generated once, when this module is imported.
session_flags = {
    'secure': (
        {'Session': {'key': str_to_bytes(str(uuid.uuid4())),
                     'keyfile': ''}},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """,
    ),
    'no-secure': (
        {'Session': {'key': b'', 'keyfile': ''}},
        """Don't authenticate messages.""",
    ),
}
115 102
def default_secure(cfg):
    """Make a config environment secure by default.

    When neither Session.key nor Session.keyfile has been set in ``cfg``,
    Session.key is assigned a freshly generated random UUID. A config that
    already names a key or keyfile is left untouched.
    """
    already_configured = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if already_configured:
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128 115
129 116
130 117 #-----------------------------------------------------------------------------
131 118 # Classes
132 119 #-----------------------------------------------------------------------------
133 120
class SessionFactory(LoggingConfigurable):
    """Base class for configurables that carry a Session, a zmq Context,
    a logger, and an IOLoop.
    """

    # name of the logger to use; changing it swaps self.log
    logname = Unicode('')

    def _logname_changed(self, name, old, new):
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')

    def _context_default(self):
        return zmq.Context.instance()

    session = Instance('IPython.kernel.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)

    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        """Initialize the traits, building a Session from the same keyword
        arguments if none was supplied.
        """
        super(SessionFactory, self).__init__(**kwargs)
        if self.session is not None:
            return
        # construct the session
        self.session = Session(**kwargs)
160 147
161 148
class Message(object):
    """Attribute-style view of a message dict.

    Build a Message from a dict; nested dict values are wrapped in Message
    objects as well. ``dict(msg_obj)`` converts the top level back to a dict.
    """

    def __init__(self, msg_dict):
        attrs = self.__dict__
        for key, value in iteritems(dict(msg_dict)):
            # wrap nested dicts so attribute access works at every level
            attrs[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # Yielding (key, value) pairs is what lets dict(msg_obj) work.
        return iter(iteritems(self.__dict__))

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
191 178
def msg_header(msg_id, msg_type, username, session):
    """Return a header dict holding the four arguments plus the current
    local time under the ``date`` key.
    """
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
    }
195 182
def extract_header(msg_or_header):
    """Given a message or header, return the header as a dict.

    A falsy argument yields an empty dict. An argument with neither a
    ``header`` nor a ``msg_id`` key raises KeyError.
    """
    if not msg_or_header:
        return {}
    try:
        # A full message carries its header under 'header'.
        header = msg_or_header['header']
    except KeyError:
        # Otherwise it must itself be a header; probing 'msg_id' raises
        # KeyError when it is not.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
214 201
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : bytes
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug = Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName('json', config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
    def _unpacker_changed(self, name, old, new):
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.packer = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.packer = new
        else:
            self.unpack = import_item(str(new))

    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        u = unicode_type(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # message signature related traits:

    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")
    def _key_changed(self, name, old, new):
        self._new_auth()

    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")
    def _signature_scheme_changed(self, name, old, new):
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        # Rebuild the signer: if the key was set before the scheme, self.auth
        # would otherwise keep using the previous digest.
        self._new_auth()

    digest_mod = Any()
    def _digest_mod_default(self):
        return hashlib.sha256

    auth = Instance(hmac.HMAC)

    def _new_auth(self):
        """(Re)build self.auth from the current key and digest_mod.

        self.auth is None when the key is empty, which disables signing and
        signature checking.
        """
        if self.key:
            self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
        else:
            self.auth = None

    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()

    # serialization traits:

    pack = Any(default_packer) # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))

    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )


    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        # packed empty dict, used by serialize() when content is None
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        self.pid = os.getpid()

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """Check that pack/unpack round-trip a simple message and handle dates.

        Raises ValueError if the packer cannot serialize a simple message,
        does not produce bytes, or the unpacker does not invert it.  If the
        pair cannot pack a datetime (or unpacks one back into a datetime),
        self.pack is wrapped so that dates are squashed before packing.
        """
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        simple = dict(a=[1,'hi'])
        try:
            packed = pack(simple)
        except Exception as e:
            template = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
            )

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
            # explicit raise rather than assert, so the check survives -O
            if unpacked != simple:
                raise ValueError("unpacked %r, expected %r" % (unpacked, simple))
        except Exception as e:
            template = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
            if self.packer == 'json':
                jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
            else:
                jsonmsg = ""
            raise ValueError(
                template.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
            )

        # check datetime support
        dated = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(dated))
            if isinstance(unpacked['t'], datetime):
                raise ValueError("Shouldn't deserialize to datetime")
        except Exception:
            # any failure here means the pair can't be trusted with dates:
            # squash them before packing
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: unpack(s)

    def msg_header(self, msg_type):
        """Return a new header for msg_type, with a fresh msg_id and this
        Session's username and session id."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/unserialize methods converts this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        # start from a copy of the Session-wide metadata; per-message
        # metadata overrides it
        msg['metadata'] = self.metadata.copy()
        if metadata is not None:
            msg['metadata'].update(metadata)
        return msg

    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_metadata,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        # copy, so the keyed-but-unfed HMAC object can be reused
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format::

                [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
                 p_metadata, p_content, buffer1, buffer2, ...]

            In this list, the ``p_*`` entities are the packed or serialized
            versions, so if JSON is used, these are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode_type):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s"%type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        self.pack(msg['metadata']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
         buffer1,buffer2,...]

        The serialize/unserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message, with the send tracker stored under
            ``'tracker'``.  None is returned (and nothing is sent) when called
            from a process other than the one that created this Session.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if not os.getpid() == self.pid:
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # copy only when every frame is below copy_threshold
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send a already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.  (None, None) is returned when
            the receive fails with EAGAIN.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None,None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        # unserialize errors propagate to the caller unchanged
        return idents, self.unserialize(msg_list, content=content, copy=copy)

    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...]
            and should be unpackable/unserializable via self.unserialize at this
            point.

        Raises
        ------
        ValueError
            If DELIM does not appear in msg_list.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx,m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

    def _add_digest(self, signature):
        """add a digest to history to protect against replay attacks"""
        if self.digest_history_size == 0:
            # no history, never add digests
            return

        self.digest_history.add(signature)
        if len(self.digest_history) > self.digest_history_size:
            # threshold reached, cull 10%
            self._cull_digest_history()

    def _cull_digest_history(self):
        """cull the digest history

        Removes a randomly selected 10% of the digest history
        (more, if that is what it takes to get back under
        digest_history_size).
        """
        current = len(self.digest_history)
        n_to_cull = max(int(current // 10), current - self.digest_history_size)
        if n_to_cull >= current:
            self.digest_history = set()
            return
        # random.sample() requires a sequence; passing a set is a TypeError
        # on Python 3.11+.
        to_cull = random.sample(tuple(self.digest_history), n_to_cull)
        self.digest_history.difference_update(to_cull)

    @staticmethod
    def _signatures_match(signature, check):
        """Compare two signatures, in constant time where possible."""
        compare_digest = getattr(hmac, 'compare_digest', None)
        if compare_digest is not None:
            try:
                return compare_digest(signature, check)
            except TypeError:
                # operand types compare_digest rejects: use plain equality
                pass
        return signature == check

    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, msg_id,
            msg_type, parent_header, metadata, content, buffers].

        Raises
        ------
        TypeError
            If msg_list has fewer than 5 parts.
        ValueError
            If signing is enabled and the message is unsigned, carries a
            signature that was already seen, or carries an invalid signature.
        """
        minlen = 5
        message = {}
        # Validate the length before indexing, so a short message fails with
        # the intended error rather than an IndexError or a bogus signature
        # failure.
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        if not copy:
            # the first minlen parts are needed as bytes below
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            check = self.sign(msg_list[1:5])
            if not self._signatures_match(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        header = self.unpack(msg_list[1])
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            message['content'] = msg_list[4]

        message['buffers'] = msg_list[5:]
        return message
834 821
def test_msg2obj():
    """Check attribute access, item access and dict round-tripping of Message."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts are wrapped as well
    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    # item access works at every level
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # dict() turns a Message back into a dict
    restored = dict(wrapped)
    assert source['x'] == restored['x']
    assert source['y']['z'] == restored['y']['z']
850 837
General Comments 0
You need to be logged in to leave comments. Login now