##// END OF EJS Templates
build new Session auth when either key or scheme changes
MinRK -
Show More
@@ -1,852 +1,856 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9 """
10 10
11 11 # Copyright (c) IPython Development Team.
12 12 # Distributed under the terms of the Modified BSD License.
13 13
14 14 import hashlib
15 15 import hmac
16 16 import logging
17 17 import os
18 18 import pprint
19 19 import random
20 20 import uuid
21 21 from datetime import datetime
22 22
23 23 try:
24 24 import cPickle
25 25 pickle = cPickle
26 26 except:
27 27 cPickle = None
28 28 import pickle
29 29
30 30 import zmq
31 31 from zmq.utils import jsonapi
32 32 from zmq.eventloop.ioloop import IOLoop
33 33 from zmq.eventloop.zmqstream import ZMQStream
34 34
35 35 from IPython.core.release import kernel_protocol_version, kernel_protocol_version_info
36 36 from IPython.config.configurable import Configurable, LoggingConfigurable
37 37 from IPython.utils import io
38 38 from IPython.utils.importstring import import_item
39 39 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
40 40 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
41 41 iteritems)
42 42 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
43 43 DottedObjectName, CUnicode, Dict, Integer,
44 44 TraitError,
45 45 )
46 46 from IPython.kernel.adapter import adapt
47 47 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
48 48
49 49 #-----------------------------------------------------------------------------
50 50 # utility functions
51 51 #-----------------------------------------------------------------------------
52 52
53 53 def squash_unicode(obj):
54 54 """coerce unicode back to bytestrings."""
55 55 if isinstance(obj,dict):
56 56 for key in obj.keys():
57 57 obj[key] = squash_unicode(obj[key])
58 58 if isinstance(key, unicode_type):
59 59 obj[squash_unicode(key)] = obj.pop(key)
60 60 elif isinstance(obj, list):
61 61 for i,v in enumerate(obj):
62 62 obj[i] = squash_unicode(v)
63 63 elif isinstance(obj, unicode_type):
64 64 obj = obj.encode('utf8')
65 65 return obj
66 66
67 67 #-----------------------------------------------------------------------------
68 68 # globals and defaults
69 69 #-----------------------------------------------------------------------------
70 70
71 71 # ISO8601-ify datetime objects
72 72 # allow unicode
73 73 # disallow nan, because it's not actually valid JSON
74 74 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
75 75 ensure_ascii=False, allow_nan=False,
76 76 )
77 77 json_unpacker = lambda s: jsonapi.loads(s)
78 78
79 79 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
80 80 pickle_unpacker = pickle.loads
81 81
82 82 default_packer = json_packer
83 83 default_unpacker = json_unpacker
84 84
85 85 DELIM = b"<IDS|MSG>"
86 86 # singleton dummy tracker, which will always report as done
87 87 DONE = zmq.MessageTracker()
88 88
89 89 #-----------------------------------------------------------------------------
90 90 # Mixin tools for apps that use Sessions
91 91 #-----------------------------------------------------------------------------
92 92
93 93 session_aliases = dict(
94 94 ident = 'Session.session',
95 95 user = 'Session.username',
96 96 keyfile = 'Session.keyfile',
97 97 )
98 98
99 99 session_flags = {
100 100 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
101 101 'keyfile' : '' }},
102 102 """Use HMAC digests for authentication of messages.
103 103 Setting this flag will generate a new UUID to use as the HMAC key.
104 104 """),
105 105 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
106 106 """Don't authenticate messages."""),
107 107 }
108 108
109 109 def default_secure(cfg):
110 110 """Set the default behavior for a config environment to be secure.
111 111
112 112 If Session.key/keyfile have not been set, set Session.key to
113 113 a new random UUID.
114 114 """
115 115
116 116 if 'Session' in cfg:
117 117 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
118 118 return
119 119 # key/keyfile not specified, generate new UUID:
120 120 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
121 121
122 122
123 123 #-----------------------------------------------------------------------------
124 124 # Classes
125 125 #-----------------------------------------------------------------------------
126 126
127 127 class SessionFactory(LoggingConfigurable):
128 128 """The Base class for configurables that have a Session, Context, logger,
129 129 and IOLoop.
130 130 """
131 131
132 132 logname = Unicode('')
133 133 def _logname_changed(self, name, old, new):
134 134 self.log = logging.getLogger(new)
135 135
136 136 # not configurable:
137 137 context = Instance('zmq.Context')
138 138 def _context_default(self):
139 139 return zmq.Context.instance()
140 140
141 141 session = Instance('IPython.kernel.zmq.session.Session')
142 142
143 143 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
144 144 def _loop_default(self):
145 145 return IOLoop.instance()
146 146
147 147 def __init__(self, **kwargs):
148 148 super(SessionFactory, self).__init__(**kwargs)
149 149
150 150 if self.session is None:
151 151 # construct the session
152 152 self.session = Session(**kwargs)
153 153
154 154
155 155 class Message(object):
156 156 """A simple message object that maps dict keys to attributes.
157 157
158 158 A Message can be created from a dict and a dict from a Message instance
159 159 simply by calling dict(msg_obj)."""
160 160
161 161 def __init__(self, msg_dict):
162 162 dct = self.__dict__
163 163 for k, v in iteritems(dict(msg_dict)):
164 164 if isinstance(v, dict):
165 165 v = Message(v)
166 166 dct[k] = v
167 167
168 168 # Having this iterator lets dict(msg_obj) work out of the box.
169 169 def __iter__(self):
170 170 return iter(iteritems(self.__dict__))
171 171
172 172 def __repr__(self):
173 173 return repr(self.__dict__)
174 174
175 175 def __str__(self):
176 176 return pprint.pformat(self.__dict__)
177 177
178 178 def __contains__(self, k):
179 179 return k in self.__dict__
180 180
181 181 def __getitem__(self, k):
182 182 return self.__dict__[k]
183 183
184 184
185 185 def msg_header(msg_id, msg_type, username, session):
186 186 date = datetime.now()
187 187 version = kernel_protocol_version
188 188 return locals()
189 189
190 190 def extract_header(msg_or_header):
191 191 """Given a message or header, return the header."""
192 192 if not msg_or_header:
193 193 return {}
194 194 try:
195 195 # See if msg_or_header is the entire message.
196 196 h = msg_or_header['header']
197 197 except KeyError:
198 198 try:
199 199 # See if msg_or_header is just the header
200 200 h = msg_or_header['msg_id']
201 201 except KeyError:
202 202 raise
203 203 else:
204 204 h = msg_or_header
205 205 if not isinstance(h, dict):
206 206 h = dict(h)
207 207 return h
208 208
209 209 class Session(Configurable):
210 210 """Object for handling serialization and sending of messages.
211 211
212 212 The Session object handles building messages and sending them
213 213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 214 other over the network via Session objects, and only need to work with the
215 215 dict-based IPython message spec. The Session will handle
216 216 serialization/deserialization, security, and metadata.
217 217
218 218 Sessions support configurable serialiization via packer/unpacker traits,
219 219 and signing with HMAC digests via the key/keyfile traits.
220 220
221 221 Parameters
222 222 ----------
223 223
224 224 debug : bool
225 225 whether to trigger extra debugging statements
226 226 packer/unpacker : str : 'json', 'pickle' or import_string
227 227 importstrings for methods to serialize message parts. If just
228 228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 229 Otherwise, the entire importstring must be used.
230 230
231 231 The functions must accept at least valid JSON input, and output *bytes*.
232 232
233 233 For example, to use msgpack:
234 234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 235 pack/unpack : callables
236 236 You can also set the pack/unpack callables for serialization directly.
237 237 session : bytes
238 238 the ID of this Session object. The default is to generate a new UUID.
239 239 username : unicode
240 240 username added to message headers. The default is to ask the OS.
241 241 key : bytes
242 242 The key used to initialize an HMAC signature. If unset, messages
243 243 will not be signed or checked.
244 244 keyfile : filepath
245 245 The file containing a key. If this is set, `key` will be initialized
246 246 to the contents of the file.
247 247
248 248 """
249 249
250 250 debug=Bool(False, config=True, help="""Debug output in the Session""")
251 251
252 252 packer = DottedObjectName('json',config=True,
253 253 help="""The name of the packer for serializing messages.
254 254 Should be one of 'json', 'pickle', or an import name
255 255 for a custom callable serializer.""")
256 256 def _packer_changed(self, name, old, new):
257 257 if new.lower() == 'json':
258 258 self.pack = json_packer
259 259 self.unpack = json_unpacker
260 260 self.unpacker = new
261 261 elif new.lower() == 'pickle':
262 262 self.pack = pickle_packer
263 263 self.unpack = pickle_unpacker
264 264 self.unpacker = new
265 265 else:
266 266 self.pack = import_item(str(new))
267 267
268 268 unpacker = DottedObjectName('json', config=True,
269 269 help="""The name of the unpacker for unserializing messages.
270 270 Only used with custom functions for `packer`.""")
271 271 def _unpacker_changed(self, name, old, new):
272 272 if new.lower() == 'json':
273 273 self.pack = json_packer
274 274 self.unpack = json_unpacker
275 275 self.packer = new
276 276 elif new.lower() == 'pickle':
277 277 self.pack = pickle_packer
278 278 self.unpack = pickle_unpacker
279 279 self.packer = new
280 280 else:
281 281 self.unpack = import_item(str(new))
282 282
283 283 session = CUnicode(u'', config=True,
284 284 help="""The UUID identifying this session.""")
285 285 def _session_default(self):
286 286 u = unicode_type(uuid.uuid4())
287 287 self.bsession = u.encode('ascii')
288 288 return u
289 289
290 290 def _session_changed(self, name, old, new):
291 291 self.bsession = self.session.encode('ascii')
292 292
293 293 # bsession is the session as bytes
294 294 bsession = CBytes(b'')
295 295
296 296 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
297 297 help="""Username for the Session. Default is your system username.""",
298 298 config=True)
299 299
300 300 metadata = Dict({}, config=True,
301 301 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
302 302
303 303 # if 0, no adapting to do.
304 304 adapt_version = Integer(0)
305 305
306 306 # message signature related traits:
307 307
308 308 key = CBytes(b'', config=True,
309 309 help="""execution key, for extra authentication.""")
310 def _key_changed(self, name, old, new):
311 if new:
312 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
313 else:
314 self.auth = None
310 def _key_changed(self):
311 self._new_auth()
315 312
316 313 signature_scheme = Unicode('hmac-sha256', config=True,
317 314 help="""The digest scheme used to construct the message signatures.
318 315 Must have the form 'hmac-HASH'.""")
319 316 def _signature_scheme_changed(self, name, old, new):
320 317 if not new.startswith('hmac-'):
321 318 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
322 319 hash_name = new.split('-', 1)[1]
323 320 try:
324 321 self.digest_mod = getattr(hashlib, hash_name)
325 322 except AttributeError:
326 323 raise TraitError("hashlib has no such attribute: %s" % hash_name)
324 self._new_auth()
327 325
328 326 digest_mod = Any()
329 327 def _digest_mod_default(self):
330 328 return hashlib.sha256
331 329
332 330 auth = Instance(hmac.HMAC)
333 331
332 def _new_auth(self):
333 if self.key:
334 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
335 else:
336 self.auth = None
337
334 338 digest_history = Set()
335 339 digest_history_size = Integer(2**16, config=True,
336 340 help="""The maximum number of digests to remember.
337 341
338 342 The digest history will be culled when it exceeds this value.
339 343 """
340 344 )
341 345
342 346 keyfile = Unicode('', config=True,
343 347 help="""path to file containing execution key.""")
344 348 def _keyfile_changed(self, name, old, new):
345 349 with open(new, 'rb') as f:
346 350 self.key = f.read().strip()
347 351
348 352 # for protecting against sends from forks
349 353 pid = Integer()
350 354
351 355 # serialization traits:
352 356
353 357 pack = Any(default_packer) # the actual packer function
354 358 def _pack_changed(self, name, old, new):
355 359 if not callable(new):
356 360 raise TypeError("packer must be callable, not %s"%type(new))
357 361
358 362 unpack = Any(default_unpacker) # the actual packer function
359 363 def _unpack_changed(self, name, old, new):
360 364 # unpacker is not checked - it is assumed to be
361 365 if not callable(new):
362 366 raise TypeError("unpacker must be callable, not %s"%type(new))
363 367
364 368 # thresholds:
365 369 copy_threshold = Integer(2**16, config=True,
366 370 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
367 371 buffer_threshold = Integer(MAX_BYTES, config=True,
368 372 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
369 373 item_threshold = Integer(MAX_ITEMS, config=True,
370 374 help="""The maximum number of items for a container to be introspected for custom serialization.
371 375 Containers larger than this are pickled outright.
372 376 """
373 377 )
374 378
375 379
376 380 def __init__(self, **kwargs):
377 381 """create a Session object
378 382
379 383 Parameters
380 384 ----------
381 385
382 386 debug : bool
383 387 whether to trigger extra debugging statements
384 388 packer/unpacker : str : 'json', 'pickle' or import_string
385 389 importstrings for methods to serialize message parts. If just
386 390 'json' or 'pickle', predefined JSON and pickle packers will be used.
387 391 Otherwise, the entire importstring must be used.
388 392
389 393 The functions must accept at least valid JSON input, and output
390 394 *bytes*.
391 395
392 396 For example, to use msgpack:
393 397 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
394 398 pack/unpack : callables
395 399 You can also set the pack/unpack callables for serialization
396 400 directly.
397 401 session : unicode (must be ascii)
398 402 the ID of this Session object. The default is to generate a new
399 403 UUID.
400 404 bsession : bytes
401 405 The session as bytes
402 406 username : unicode
403 407 username added to message headers. The default is to ask the OS.
404 408 key : bytes
405 409 The key used to initialize an HMAC signature. If unset, messages
406 410 will not be signed or checked.
407 411 signature_scheme : str
408 412 The message digest scheme. Currently must be of the form 'hmac-HASH',
409 413 where 'HASH' is a hashing function available in Python's hashlib.
410 414 The default is 'hmac-sha256'.
411 415 This is ignored if 'key' is empty.
412 416 keyfile : filepath
413 417 The file containing a key. If this is set, `key` will be
414 418 initialized to the contents of the file.
415 419 """
416 420 super(Session, self).__init__(**kwargs)
417 421 self._check_packers()
418 422 self.none = self.pack({})
419 423 # ensure self._session_default() if necessary, so bsession is defined:
420 424 self.session
421 425 self.pid = os.getpid()
422 426
423 427 @property
424 428 def msg_id(self):
425 429 """always return new uuid"""
426 430 return str(uuid.uuid4())
427 431
428 432 def _check_packers(self):
429 433 """check packers for datetime support."""
430 434 pack = self.pack
431 435 unpack = self.unpack
432 436
433 437 # check simple serialization
434 438 msg = dict(a=[1,'hi'])
435 439 try:
436 440 packed = pack(msg)
437 441 except Exception as e:
438 442 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
439 443 if self.packer == 'json':
440 444 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
441 445 else:
442 446 jsonmsg = ""
443 447 raise ValueError(
444 448 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
445 449 )
446 450
447 451 # ensure packed message is bytes
448 452 if not isinstance(packed, bytes):
449 453 raise ValueError("message packed to %r, but bytes are required"%type(packed))
450 454
451 455 # check that unpack is pack's inverse
452 456 try:
453 457 unpacked = unpack(packed)
454 458 assert unpacked == msg
455 459 except Exception as e:
456 460 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
457 461 if self.packer == 'json':
458 462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
459 463 else:
460 464 jsonmsg = ""
461 465 raise ValueError(
462 466 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
463 467 )
464 468
465 469 # check datetime support
466 470 msg = dict(t=datetime.now())
467 471 try:
468 472 unpacked = unpack(pack(msg))
469 473 if isinstance(unpacked['t'], datetime):
470 474 raise ValueError("Shouldn't deserialize to datetime")
471 475 except Exception:
472 476 self.pack = lambda o: pack(squash_dates(o))
473 477 self.unpack = lambda s: unpack(s)
474 478
475 479 def msg_header(self, msg_type):
476 480 return msg_header(self.msg_id, msg_type, self.username, self.session)
477 481
478 482 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
479 483 """Return the nested message dict.
480 484
481 485 This format is different from what is sent over the wire. The
482 486 serialize/unserialize methods converts this nested message dict to the wire
483 487 format, which is a list of message parts.
484 488 """
485 489 msg = {}
486 490 header = self.msg_header(msg_type) if header is None else header
487 491 msg['header'] = header
488 492 msg['msg_id'] = header['msg_id']
489 493 msg['msg_type'] = header['msg_type']
490 494 msg['parent_header'] = {} if parent is None else extract_header(parent)
491 495 msg['content'] = {} if content is None else content
492 496 msg['metadata'] = self.metadata.copy()
493 497 if metadata is not None:
494 498 msg['metadata'].update(metadata)
495 499 return msg
496 500
497 501 def sign(self, msg_list):
498 502 """Sign a message with HMAC digest. If no auth, return b''.
499 503
500 504 Parameters
501 505 ----------
502 506 msg_list : list
503 507 The [p_header,p_parent,p_content] part of the message list.
504 508 """
505 509 if self.auth is None:
506 510 return b''
507 511 h = self.auth.copy()
508 512 for m in msg_list:
509 513 h.update(m)
510 514 return str_to_bytes(h.hexdigest())
511 515
512 516 def serialize(self, msg, ident=None):
513 517 """Serialize the message components to bytes.
514 518
515 519 This is roughly the inverse of unserialize. The serialize/unserialize
516 520 methods work with full message lists, whereas pack/unpack work with
517 521 the individual message parts in the message list.
518 522
519 523 Parameters
520 524 ----------
521 525 msg : dict or Message
522 526 The nexted message dict as returned by the self.msg method.
523 527
524 528 Returns
525 529 -------
526 530 msg_list : list
527 531 The list of bytes objects to be sent with the format::
528 532
529 533 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
530 534 p_metadata, p_content, buffer1, buffer2, ...]
531 535
532 536 In this list, the ``p_*`` entities are the packed or serialized
533 537 versions, so if JSON is used, these are utf8 encoded JSON strings.
534 538 """
535 539 content = msg.get('content', {})
536 540 if content is None:
537 541 content = self.none
538 542 elif isinstance(content, dict):
539 543 content = self.pack(content)
540 544 elif isinstance(content, bytes):
541 545 # content is already packed, as in a relayed message
542 546 pass
543 547 elif isinstance(content, unicode_type):
544 548 # should be bytes, but JSON often spits out unicode
545 549 content = content.encode('utf8')
546 550 else:
547 551 raise TypeError("Content incorrect type: %s"%type(content))
548 552
549 553 real_message = [self.pack(msg['header']),
550 554 self.pack(msg['parent_header']),
551 555 self.pack(msg['metadata']),
552 556 content,
553 557 ]
554 558
555 559 to_send = []
556 560
557 561 if isinstance(ident, list):
558 562 # accept list of idents
559 563 to_send.extend(ident)
560 564 elif ident is not None:
561 565 to_send.append(ident)
562 566 to_send.append(DELIM)
563 567
564 568 signature = self.sign(real_message)
565 569 to_send.append(signature)
566 570
567 571 to_send.extend(real_message)
568 572
569 573 return to_send
570 574
571 575 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
572 576 buffers=None, track=False, header=None, metadata=None):
573 577 """Build and send a message via stream or socket.
574 578
575 579 The message format used by this function internally is as follows:
576 580
577 581 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
578 582 buffer1,buffer2,...]
579 583
580 584 The serialize/unserialize methods convert the nested message dict into this
581 585 format.
582 586
583 587 Parameters
584 588 ----------
585 589
586 590 stream : zmq.Socket or ZMQStream
587 591 The socket-like object used to send the data.
588 592 msg_or_type : str or Message/dict
589 593 Normally, msg_or_type will be a msg_type unless a message is being
590 594 sent more than once. If a header is supplied, this can be set to
591 595 None and the msg_type will be pulled from the header.
592 596
593 597 content : dict or None
594 598 The content of the message (ignored if msg_or_type is a message).
595 599 header : dict or None
596 600 The header dict for the message (ignored if msg_to_type is a message).
597 601 parent : Message or dict or None
598 602 The parent or parent header describing the parent of this message
599 603 (ignored if msg_or_type is a message).
600 604 ident : bytes or list of bytes
601 605 The zmq.IDENTITY routing path.
602 606 metadata : dict or None
603 607 The metadata describing the message
604 608 buffers : list or None
605 609 The already-serialized buffers to be appended to the message.
606 610 track : bool
607 611 Whether to track. Only for use with Sockets, because ZMQStream
608 612 objects cannot track messages.
609 613
610 614
611 615 Returns
612 616 -------
613 617 msg : dict
614 618 The constructed message.
615 619 """
616 620 if not isinstance(stream, zmq.Socket):
617 621 # ZMQStreams and dummy sockets do not support tracking.
618 622 track = False
619 623
620 624 if isinstance(msg_or_type, (Message, dict)):
621 625 # We got a Message or message dict, not a msg_type so don't
622 626 # build a new Message.
623 627 msg = msg_or_type
624 628 else:
625 629 msg = self.msg(msg_or_type, content=content, parent=parent,
626 630 header=header, metadata=metadata)
627 631 if not os.getpid() == self.pid:
628 632 io.rprint("WARNING: attempted to send message from fork")
629 633 io.rprint(msg)
630 634 return
631 635 buffers = [] if buffers is None else buffers
632 636 if self.adapt_version:
633 637 msg = adapt(msg, self.adapt_version)
634 638 to_send = self.serialize(msg, ident)
635 639 to_send.extend(buffers)
636 640 longest = max([ len(s) for s in to_send ])
637 641 copy = (longest < self.copy_threshold)
638 642
639 643 if buffers and track and not copy:
640 644 # only really track when we are doing zero-copy buffers
641 645 tracker = stream.send_multipart(to_send, copy=False, track=True)
642 646 else:
643 647 # use dummy tracker, which will be done immediately
644 648 tracker = DONE
645 649 stream.send_multipart(to_send, copy=copy)
646 650
647 651 if self.debug:
648 652 pprint.pprint(msg)
649 653 pprint.pprint(to_send)
650 654 pprint.pprint(buffers)
651 655
652 656 msg['tracker'] = tracker
653 657
654 658 return msg
655 659
656 660 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
657 661 """Send a raw message via ident path.
658 662
659 663 This method is used to send a already serialized message.
660 664
661 665 Parameters
662 666 ----------
663 667 stream : ZMQStream or Socket
664 668 The ZMQ stream or socket to use for sending the message.
665 669 msg_list : list
666 670 The serialized list of messages to send. This only includes the
667 671 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
668 672 the message.
669 673 ident : ident or list
670 674 A single ident or a list of idents to use in sending.
671 675 """
672 676 to_send = []
673 677 if isinstance(ident, bytes):
674 678 ident = [ident]
675 679 if ident is not None:
676 680 to_send.extend(ident)
677 681
678 682 to_send.append(DELIM)
679 683 to_send.append(self.sign(msg_list))
680 684 to_send.extend(msg_list)
681 685 stream.send_multipart(to_send, flags, copy=copy)
682 686
683 687 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
684 688 """Receive and unpack a message.
685 689
686 690 Parameters
687 691 ----------
688 692 socket : ZMQStream or Socket
689 693 The socket or stream to use in receiving.
690 694
691 695 Returns
692 696 -------
693 697 [idents], msg
694 698 [idents] is a list of idents and msg is a nested message dict of
695 699 same format as self.msg returns.
696 700 """
697 701 if isinstance(socket, ZMQStream):
698 702 socket = socket.socket
699 703 try:
700 704 msg_list = socket.recv_multipart(mode, copy=copy)
701 705 except zmq.ZMQError as e:
702 706 if e.errno == zmq.EAGAIN:
703 707 # We can convert EAGAIN to None as we know in this case
704 708 # recv_multipart won't return None.
705 709 return None,None
706 710 else:
707 711 raise
708 712 # split multipart message into identity list and message dict
709 713 # invalid large messages can cause very expensive string comparisons
710 714 idents, msg_list = self.feed_identities(msg_list, copy)
711 715 try:
712 716 return idents, self.unserialize(msg_list, content=content, copy=copy)
713 717 except Exception as e:
714 718 # TODO: handle it
715 719 raise e
716 720
717 721 def feed_identities(self, msg_list, copy=True):
718 722 """Split the identities from the rest of the message.
719 723
720 724 Feed until DELIM is reached, then return the prefix as idents and
721 725 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
722 726 but that would be silly.
723 727
724 728 Parameters
725 729 ----------
726 730 msg_list : a list of Message or bytes objects
727 731 The message to be split.
728 732 copy : bool
729 733 flag determining whether the arguments are bytes or Messages
730 734
731 735 Returns
732 736 -------
733 737 (idents, msg_list) : two lists
734 738 idents will always be a list of bytes, each of which is a ZMQ
735 739 identity. msg_list will be a list of bytes or zmq.Messages of the
736 740 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
737 741 should be unpackable/unserializable via self.unserialize at this
738 742 point.
739 743 """
740 744 if copy:
741 745 idx = msg_list.index(DELIM)
742 746 return msg_list[:idx], msg_list[idx+1:]
743 747 else:
744 748 failed = True
745 749 for idx,m in enumerate(msg_list):
746 750 if m.bytes == DELIM:
747 751 failed = False
748 752 break
749 753 if failed:
750 754 raise ValueError("DELIM not in msg_list")
751 755 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
752 756 return [m.bytes for m in idents], msg_list
753 757
754 758 def _add_digest(self, signature):
755 759 """add a digest to history to protect against replay attacks"""
756 760 if self.digest_history_size == 0:
757 761 # no history, never add digests
758 762 return
759 763
760 764 self.digest_history.add(signature)
761 765 if len(self.digest_history) > self.digest_history_size:
762 766 # threshold reached, cull 10%
763 767 self._cull_digest_history()
764 768
765 769 def _cull_digest_history(self):
766 770 """cull the digest history
767 771
768 772 Removes a randomly selected 10% of the digest history
769 773 """
770 774 current = len(self.digest_history)
771 775 n_to_cull = max(int(current // 10), current - self.digest_history_size)
772 776 if n_to_cull >= current:
773 777 self.digest_history = set()
774 778 return
775 779 to_cull = random.sample(self.digest_history, n_to_cull)
776 780 self.digest_history.difference_update(to_cull)
777 781
778 782 def unserialize(self, msg_list, content=True, copy=True):
779 783 """Unserialize a msg_list to a nested message dict.
780 784
781 785 This is roughly the inverse of serialize. The serialize/unserialize
782 786 methods work with full message lists, whereas pack/unpack work with
783 787 the individual message parts in the message list.
784 788
785 789 Parameters
786 790 ----------
787 791 msg_list : list of bytes or Message objects
788 792 The list of message parts of the form [HMAC,p_header,p_parent,
789 793 p_metadata,p_content,buffer1,buffer2,...].
790 794 content : bool (True)
791 795 Whether to unpack the content dict (True), or leave it packed
792 796 (False).
793 797 copy : bool (True)
794 798 Whether to return the bytes (True), or the non-copying Message
795 799 object in each place (False).
796 800
797 801 Returns
798 802 -------
799 803 msg : dict
800 804 The nested message dict with top-level keys [header, parent_header,
801 805 content, buffers].
802 806 """
803 807 minlen = 5
804 808 message = {}
805 809 if not copy:
806 810 for i in range(minlen):
807 811 msg_list[i] = msg_list[i].bytes
808 812 if self.auth is not None:
809 813 signature = msg_list[0]
810 814 if not signature:
811 815 raise ValueError("Unsigned Message")
812 816 if signature in self.digest_history:
813 817 raise ValueError("Duplicate Signature: %r" % signature)
814 818 self._add_digest(signature)
815 819 check = self.sign(msg_list[1:5])
816 820 if not signature == check:
817 821 raise ValueError("Invalid Signature: %r" % signature)
818 822 if not len(msg_list) >= minlen:
819 823 raise TypeError("malformed message, must have at least %i elements"%minlen)
820 824 header = self.unpack(msg_list[1])
821 825 message['header'] = extract_dates(header)
822 826 message['msg_id'] = header['msg_id']
823 827 message['msg_type'] = header['msg_type']
824 828 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
825 829 message['metadata'] = self.unpack(msg_list[3])
826 830 if content:
827 831 message['content'] = self.unpack(msg_list[4])
828 832 else:
829 833 message['content'] = msg_list[4]
830 834
831 835 message['buffers'] = msg_list[5:]
832 836 # print("received: %s: %s\n %s" % (message['msg_type'], message['header'], message['content']))
833 837 # adapt to the current version
834 838 return adapt(message)
835 839 # print("adapted: %s: %s\n %s" % (adapted['msg_type'], adapted['header'], adapted['content']))
836 840
837 841 def test_msg2obj():
838 842 am = dict(x=1)
839 843 ao = Message(am)
840 844 assert ao.x == am['x']
841 845
842 846 am['y'] = dict(z=1)
843 847 ao = Message(am)
844 848 assert ao.y.z == am['y']['z']
845 849
846 850 k1, k2 = 'y', 'z'
847 851 assert ao[k1][k2] == am[k1][k2]
848 852
849 853 am2 = dict(ao)
850 854 assert am['x'] == am2['x']
851 855 assert am['y']['z'] == am2['y']['z']
852 856
General Comments 0
You need to be logged in to leave comments. Login now