##// END OF EJS Templates
Merge pull request #4840 from dsblank/master...
Thomas Kluyver -
r14648:8858f7ee merge
parent child Browse files
Show More
@@ -1,850 +1,850 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hashlib
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import random
33 33 import uuid
34 34 from datetime import datetime
35 35
36 36 try:
37 37 import cPickle
38 38 pickle = cPickle
39 39 except:
40 40 cPickle = None
41 41 import pickle
42 42
43 43 import zmq
44 44 from zmq.utils import jsonapi
45 45 from zmq.eventloop.ioloop import IOLoop
46 46 from zmq.eventloop.zmqstream import ZMQStream
47 47
48 48 from IPython.config.configurable import Configurable, LoggingConfigurable
49 49 from IPython.utils import io
50 50 from IPython.utils.importstring import import_item
51 51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
52 52 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
53 53 iteritems)
54 54 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
55 55 DottedObjectName, CUnicode, Dict, Integer,
56 56 TraitError,
57 57 )
58 58 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
59 59
60 60 #-----------------------------------------------------------------------------
61 61 # utility functions
62 62 #-----------------------------------------------------------------------------
63 63
64 64 def squash_unicode(obj):
65 65 """coerce unicode back to bytestrings."""
66 66 if isinstance(obj,dict):
67 67 for key in obj.keys():
68 68 obj[key] = squash_unicode(obj[key])
69 69 if isinstance(key, unicode_type):
70 70 obj[squash_unicode(key)] = obj.pop(key)
71 71 elif isinstance(obj, list):
72 72 for i,v in enumerate(obj):
73 73 obj[i] = squash_unicode(v)
74 74 elif isinstance(obj, unicode_type):
75 75 obj = obj.encode('utf8')
76 76 return obj
77 77
78 78 #-----------------------------------------------------------------------------
79 79 # globals and defaults
80 80 #-----------------------------------------------------------------------------
81 81
82 82 # ISO8601-ify datetime objects
83 83 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
84 84 json_unpacker = lambda s: jsonapi.loads(s)
85 85
86 86 pickle_packer = lambda o: pickle.dumps(squash_dates(o),-1)
87 87 pickle_unpacker = pickle.loads
88 88
89 89 default_packer = json_packer
90 90 default_unpacker = json_unpacker
91 91
92 92 DELIM = b"<IDS|MSG>"
93 93 # singleton dummy tracker, which will always report as done
94 94 DONE = zmq.MessageTracker()
95 95
96 96 #-----------------------------------------------------------------------------
97 97 # Mixin tools for apps that use Sessions
98 98 #-----------------------------------------------------------------------------
99 99
100 100 session_aliases = dict(
101 101 ident = 'Session.session',
102 102 user = 'Session.username',
103 103 keyfile = 'Session.keyfile',
104 104 )
105 105
106 106 session_flags = {
107 107 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
108 108 'keyfile' : '' }},
109 109 """Use HMAC digests for authentication of messages.
110 110 Setting this flag will generate a new UUID to use as the HMAC key.
111 111 """),
112 112 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
113 113 """Don't authenticate messages."""),
114 114 }
115 115
116 116 def default_secure(cfg):
117 117 """Set the default behavior for a config environment to be secure.
118 118
119 119 If Session.key/keyfile have not been set, set Session.key to
120 120 a new random UUID.
121 121 """
122 122
123 123 if 'Session' in cfg:
124 124 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
125 125 return
126 126 # key/keyfile not specified, generate new UUID:
127 127 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
128 128
129 129
130 130 #-----------------------------------------------------------------------------
131 131 # Classes
132 132 #-----------------------------------------------------------------------------
133 133
134 134 class SessionFactory(LoggingConfigurable):
135 135 """The Base class for configurables that have a Session, Context, logger,
136 136 and IOLoop.
137 137 """
138 138
139 139 logname = Unicode('')
140 140 def _logname_changed(self, name, old, new):
141 141 self.log = logging.getLogger(new)
142 142
143 143 # not configurable:
144 144 context = Instance('zmq.Context')
145 145 def _context_default(self):
146 146 return zmq.Context.instance()
147 147
148 148 session = Instance('IPython.kernel.zmq.session.Session')
149 149
150 150 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
151 151 def _loop_default(self):
152 152 return IOLoop.instance()
153 153
154 154 def __init__(self, **kwargs):
155 155 super(SessionFactory, self).__init__(**kwargs)
156 156
157 157 if self.session is None:
158 158 # construct the session
159 159 self.session = Session(**kwargs)
160 160
161 161
162 162 class Message(object):
163 163 """A simple message object that maps dict keys to attributes.
164 164
165 165 A Message can be created from a dict and a dict from a Message instance
166 166 simply by calling dict(msg_obj)."""
167 167
168 168 def __init__(self, msg_dict):
169 169 dct = self.__dict__
170 170 for k, v in iteritems(dict(msg_dict)):
171 171 if isinstance(v, dict):
172 172 v = Message(v)
173 173 dct[k] = v
174 174
175 175 # Having this iterator lets dict(msg_obj) work out of the box.
176 176 def __iter__(self):
177 177 return iter(iteritems(self.__dict__))
178 178
179 179 def __repr__(self):
180 180 return repr(self.__dict__)
181 181
182 182 def __str__(self):
183 183 return pprint.pformat(self.__dict__)
184 184
185 185 def __contains__(self, k):
186 186 return k in self.__dict__
187 187
188 188 def __getitem__(self, k):
189 189 return self.__dict__[k]
190 190
191 191
192 192 def msg_header(msg_id, msg_type, username, session):
193 193 date = datetime.now()
194 194 return locals()
195 195
196 196 def extract_header(msg_or_header):
197 197 """Given a message or header, return the header."""
198 198 if not msg_or_header:
199 199 return {}
200 200 try:
201 201 # See if msg_or_header is the entire message.
202 202 h = msg_or_header['header']
203 203 except KeyError:
204 204 try:
205 205 # See if msg_or_header is just the header
206 206 h = msg_or_header['msg_id']
207 207 except KeyError:
208 208 raise
209 209 else:
210 210 h = msg_or_header
211 211 if not isinstance(h, dict):
212 212 h = dict(h)
213 213 return h
214 214
215 215 class Session(Configurable):
216 216 """Object for handling serialization and sending of messages.
217 217
218 218 The Session object handles building messages and sending them
219 219 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
220 220 other over the network via Session objects, and only need to work with the
221 221 dict-based IPython message spec. The Session will handle
222 222 serialization/deserialization, security, and metadata.
223 223
224 224 Sessions support configurable serialiization via packer/unpacker traits,
225 225 and signing with HMAC digests via the key/keyfile traits.
226 226
227 227 Parameters
228 228 ----------
229 229
230 230 debug : bool
231 231 whether to trigger extra debugging statements
232 232 packer/unpacker : str : 'json', 'pickle' or import_string
233 233 importstrings for methods to serialize message parts. If just
234 234 'json' or 'pickle', predefined JSON and pickle packers will be used.
235 235 Otherwise, the entire importstring must be used.
236 236
237 237 The functions must accept at least valid JSON input, and output *bytes*.
238 238
239 239 For example, to use msgpack:
240 240 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
241 241 pack/unpack : callables
242 242 You can also set the pack/unpack callables for serialization directly.
243 243 session : bytes
244 244 the ID of this Session object. The default is to generate a new UUID.
245 245 username : unicode
246 246 username added to message headers. The default is to ask the OS.
247 247 key : bytes
248 248 The key used to initialize an HMAC signature. If unset, messages
249 249 will not be signed or checked.
250 250 keyfile : filepath
251 251 The file containing a key. If this is set, `key` will be initialized
252 252 to the contents of the file.
253 253
254 254 """
255 255
256 256 debug=Bool(False, config=True, help="""Debug output in the Session""")
257 257
258 258 packer = DottedObjectName('json',config=True,
259 259 help="""The name of the packer for serializing messages.
260 260 Should be one of 'json', 'pickle', or an import name
261 261 for a custom callable serializer.""")
262 262 def _packer_changed(self, name, old, new):
263 263 if new.lower() == 'json':
264 264 self.pack = json_packer
265 265 self.unpack = json_unpacker
266 266 self.unpacker = new
267 267 elif new.lower() == 'pickle':
268 268 self.pack = pickle_packer
269 269 self.unpack = pickle_unpacker
270 270 self.unpacker = new
271 271 else:
272 272 self.pack = import_item(str(new))
273 273
274 274 unpacker = DottedObjectName('json', config=True,
275 275 help="""The name of the unpacker for unserializing messages.
276 276 Only used with custom functions for `packer`.""")
277 277 def _unpacker_changed(self, name, old, new):
278 278 if new.lower() == 'json':
279 279 self.pack = json_packer
280 280 self.unpack = json_unpacker
281 281 self.packer = new
282 282 elif new.lower() == 'pickle':
283 283 self.pack = pickle_packer
284 284 self.unpack = pickle_unpacker
285 285 self.packer = new
286 286 else:
287 287 self.unpack = import_item(str(new))
288 288
289 289 session = CUnicode(u'', config=True,
290 290 help="""The UUID identifying this session.""")
291 291 def _session_default(self):
292 292 u = unicode_type(uuid.uuid4())
293 293 self.bsession = u.encode('ascii')
294 294 return u
295 295
296 296 def _session_changed(self, name, old, new):
297 297 self.bsession = self.session.encode('ascii')
298 298
299 299 # bsession is the session as bytes
300 300 bsession = CBytes(b'')
301 301
302 302 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
303 303 help="""Username for the Session. Default is your system username.""",
304 304 config=True)
305 305
306 306 metadata = Dict({}, config=True,
307 307 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
308 308
309 309 # message signature related traits:
310 310
311 311 key = CBytes(b'', config=True,
312 312 help="""execution key, for extra authentication.""")
313 313 def _key_changed(self, name, old, new):
314 314 if new:
315 315 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
316 316 else:
317 317 self.auth = None
318 318
319 319 signature_scheme = Unicode('hmac-sha256', config=True,
320 320 help="""The digest scheme used to construct the message signatures.
321 321 Must have the form 'hmac-HASH'.""")
322 322 def _signature_scheme_changed(self, name, old, new):
323 323 if not new.startswith('hmac-'):
324 324 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
325 325 hash_name = new.split('-', 1)[1]
326 326 try:
327 327 self.digest_mod = getattr(hashlib, hash_name)
328 328 except AttributeError:
329 329 raise TraitError("hashlib has no such attribute: %s" % hash_name)
330 330
331 331 digest_mod = Any()
332 332 def _digest_mod_default(self):
333 333 return hashlib.sha256
334 334
335 335 auth = Instance(hmac.HMAC)
336 336
337 337 digest_history = Set()
338 338 digest_history_size = Integer(2**16, config=True,
339 339 help="""The maximum number of digests to remember.
340 340
341 341 The digest history will be culled when it exceeds this value.
342 342 """
343 343 )
344 344
345 345 keyfile = Unicode('', config=True,
346 346 help="""path to file containing execution key.""")
347 347 def _keyfile_changed(self, name, old, new):
348 348 with open(new, 'rb') as f:
349 349 self.key = f.read().strip()
350 350
351 351 # for protecting against sends from forks
352 352 pid = Integer()
353 353
354 354 # serialization traits:
355 355
356 356 pack = Any(default_packer) # the actual packer function
357 357 def _pack_changed(self, name, old, new):
358 358 if not callable(new):
359 359 raise TypeError("packer must be callable, not %s"%type(new))
360 360
361 361 unpack = Any(default_unpacker) # the actual packer function
362 362 def _unpack_changed(self, name, old, new):
363 363 # unpacker is not checked - it is assumed to be
364 364 if not callable(new):
365 365 raise TypeError("unpacker must be callable, not %s"%type(new))
366 366
367 367 # thresholds:
368 368 copy_threshold = Integer(2**16, config=True,
369 369 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
370 370 buffer_threshold = Integer(MAX_BYTES, config=True,
371 371 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
372 372 item_threshold = Integer(MAX_ITEMS, config=True,
373 373 help="""The maximum number of items for a container to be introspected for custom serialization.
374 374 Containers larger than this are pickled outright.
375 375 """
376 376 )
377 377
378 378
379 379 def __init__(self, **kwargs):
380 380 """create a Session object
381 381
382 382 Parameters
383 383 ----------
384 384
385 385 debug : bool
386 386 whether to trigger extra debugging statements
387 387 packer/unpacker : str : 'json', 'pickle' or import_string
388 388 importstrings for methods to serialize message parts. If just
389 389 'json' or 'pickle', predefined JSON and pickle packers will be used.
390 390 Otherwise, the entire importstring must be used.
391 391
392 392 The functions must accept at least valid JSON input, and output
393 393 *bytes*.
394 394
395 395 For example, to use msgpack:
396 396 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
397 397 pack/unpack : callables
398 398 You can also set the pack/unpack callables for serialization
399 399 directly.
400 400 session : unicode (must be ascii)
401 401 the ID of this Session object. The default is to generate a new
402 402 UUID.
403 403 bsession : bytes
404 404 The session as bytes
405 405 username : unicode
406 406 username added to message headers. The default is to ask the OS.
407 407 key : bytes
408 408 The key used to initialize an HMAC signature. If unset, messages
409 409 will not be signed or checked.
410 410 signature_scheme : str
411 411 The message digest scheme. Currently must be of the form 'hmac-HASH',
412 412 where 'HASH' is a hashing function available in Python's hashlib.
413 413 The default is 'hmac-sha256'.
414 414 This is ignored if 'key' is empty.
415 415 keyfile : filepath
416 416 The file containing a key. If this is set, `key` will be
417 417 initialized to the contents of the file.
418 418 """
419 419 super(Session, self).__init__(**kwargs)
420 420 self._check_packers()
421 421 self.none = self.pack({})
422 422 # ensure self._session_default() if necessary, so bsession is defined:
423 423 self.session
424 424 self.pid = os.getpid()
425 425
426 426 @property
427 427 def msg_id(self):
428 428 """always return new uuid"""
429 429 return str(uuid.uuid4())
430 430
431 431 def _check_packers(self):
432 432 """check packers for datetime support."""
433 433 pack = self.pack
434 434 unpack = self.unpack
435 435
436 436 # check simple serialization
437 437 msg = dict(a=[1,'hi'])
438 438 try:
439 439 packed = pack(msg)
440 440 except Exception as e:
441 441 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
442 442 if self.packer == 'json':
443 443 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
444 444 else:
445 445 jsonmsg = ""
446 446 raise ValueError(
447 447 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
448 448 )
449 449
450 450 # ensure packed message is bytes
451 451 if not isinstance(packed, bytes):
452 452 raise ValueError("message packed to %r, but bytes are required"%type(packed))
453 453
454 454 # check that unpack is pack's inverse
455 455 try:
456 456 unpacked = unpack(packed)
457 457 assert unpacked == msg
458 458 except Exception as e:
459 459 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
460 460 if self.packer == 'json':
461 461 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
462 462 else:
463 463 jsonmsg = ""
464 464 raise ValueError(
465 465 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
466 466 )
467 467
468 468 # check datetime support
469 469 msg = dict(t=datetime.now())
470 470 try:
471 471 unpacked = unpack(pack(msg))
472 472 if isinstance(unpacked['t'], datetime):
473 473 raise ValueError("Shouldn't deserialize to datetime")
474 474 except Exception:
475 475 self.pack = lambda o: pack(squash_dates(o))
476 476 self.unpack = lambda s: unpack(s)
477 477
478 478 def msg_header(self, msg_type):
479 479 return msg_header(self.msg_id, msg_type, self.username, self.session)
480 480
481 481 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
482 482 """Return the nested message dict.
483 483
484 484 This format is different from what is sent over the wire. The
485 485 serialize/unserialize methods converts this nested message dict to the wire
486 486 format, which is a list of message parts.
487 487 """
488 488 msg = {}
489 489 header = self.msg_header(msg_type) if header is None else header
490 490 msg['header'] = header
491 491 msg['msg_id'] = header['msg_id']
492 492 msg['msg_type'] = header['msg_type']
493 493 msg['parent_header'] = {} if parent is None else extract_header(parent)
494 494 msg['content'] = {} if content is None else content
495 495 msg['metadata'] = self.metadata.copy()
496 496 if metadata is not None:
497 497 msg['metadata'].update(metadata)
498 498 return msg
499 499
500 500 def sign(self, msg_list):
501 501 """Sign a message with HMAC digest. If no auth, return b''.
502 502
503 503 Parameters
504 504 ----------
505 505 msg_list : list
506 506 The [p_header,p_parent,p_content] part of the message list.
507 507 """
508 508 if self.auth is None:
509 509 return b''
510 510 h = self.auth.copy()
511 511 for m in msg_list:
512 512 h.update(m)
513 513 return str_to_bytes(h.hexdigest())
514 514
515 515 def serialize(self, msg, ident=None):
516 516 """Serialize the message components to bytes.
517 517
518 518 This is roughly the inverse of unserialize. The serialize/unserialize
519 519 methods work with full message lists, whereas pack/unpack work with
520 520 the individual message parts in the message list.
521 521
522 522 Parameters
523 523 ----------
524 524 msg : dict or Message
525 525 The nexted message dict as returned by the self.msg method.
526 526
527 527 Returns
528 528 -------
529 529 msg_list : list
530 530 The list of bytes objects to be sent with the format::
531 531
532 532 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
533 533 p_metadata, p_content, buffer1, buffer2, ...]
534 534
535 535 In this list, the ``p_*`` entities are the packed or serialized
536 536 versions, so if JSON is used, these are utf8 encoded JSON strings.
537 537 """
538 538 content = msg.get('content', {})
539 539 if content is None:
540 540 content = self.none
541 541 elif isinstance(content, dict):
542 542 content = self.pack(content)
543 543 elif isinstance(content, bytes):
544 544 # content is already packed, as in a relayed message
545 545 pass
546 546 elif isinstance(content, unicode_type):
547 547 # should be bytes, but JSON often spits out unicode
548 548 content = content.encode('utf8')
549 549 else:
550 550 raise TypeError("Content incorrect type: %s"%type(content))
551 551
552 552 real_message = [self.pack(msg['header']),
553 553 self.pack(msg['parent_header']),
554 554 self.pack(msg['metadata']),
555 555 content,
556 556 ]
557 557
558 558 to_send = []
559 559
560 560 if isinstance(ident, list):
561 561 # accept list of idents
562 562 to_send.extend(ident)
563 563 elif ident is not None:
564 564 to_send.append(ident)
565 565 to_send.append(DELIM)
566 566
567 567 signature = self.sign(real_message)
568 568 to_send.append(signature)
569 569
570 570 to_send.extend(real_message)
571 571
572 572 return to_send
573 573
574 574 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
575 575 buffers=None, track=False, header=None, metadata=None):
576 576 """Build and send a message via stream or socket.
577 577
578 578 The message format used by this function internally is as follows:
579 579
580 580 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
581 581 buffer1,buffer2,...]
582 582
583 583 The serialize/unserialize methods convert the nested message dict into this
584 584 format.
585 585
586 586 Parameters
587 587 ----------
588 588
589 589 stream : zmq.Socket or ZMQStream
590 590 The socket-like object used to send the data.
591 591 msg_or_type : str or Message/dict
592 592 Normally, msg_or_type will be a msg_type unless a message is being
593 593 sent more than once. If a header is supplied, this can be set to
594 594 None and the msg_type will be pulled from the header.
595 595
596 596 content : dict or None
597 597 The content of the message (ignored if msg_or_type is a message).
598 598 header : dict or None
599 599 The header dict for the message (ignored if msg_to_type is a message).
600 600 parent : Message or dict or None
601 601 The parent or parent header describing the parent of this message
602 602 (ignored if msg_or_type is a message).
603 603 ident : bytes or list of bytes
604 604 The zmq.IDENTITY routing path.
605 605 metadata : dict or None
606 606 The metadata describing the message
607 607 buffers : list or None
608 608 The already-serialized buffers to be appended to the message.
609 609 track : bool
610 610 Whether to track. Only for use with Sockets, because ZMQStream
611 611 objects cannot track messages.
612 612
613 613
614 614 Returns
615 615 -------
616 616 msg : dict
617 617 The constructed message.
618 618 """
619 619 if not isinstance(stream, zmq.Socket):
620 620 # ZMQStreams and dummy sockets do not support tracking.
621 621 track = False
622 622
623 623 if isinstance(msg_or_type, (Message, dict)):
624 624 # We got a Message or message dict, not a msg_type so don't
625 625 # build a new Message.
626 626 msg = msg_or_type
627 627 else:
628 628 msg = self.msg(msg_or_type, content=content, parent=parent,
629 629 header=header, metadata=metadata)
630 630 if not os.getpid() == self.pid:
631 631 io.rprint("WARNING: attempted to send message from fork")
632 632 io.rprint(msg)
633 633 return
634 634 buffers = [] if buffers is None else buffers
635 635 to_send = self.serialize(msg, ident)
636 636 to_send.extend(buffers)
637 637 longest = max([ len(s) for s in to_send ])
638 638 copy = (longest < self.copy_threshold)
639 639
640 640 if buffers and track and not copy:
641 641 # only really track when we are doing zero-copy buffers
642 642 tracker = stream.send_multipart(to_send, copy=False, track=True)
643 643 else:
644 644 # use dummy tracker, which will be done immediately
645 645 tracker = DONE
646 646 stream.send_multipart(to_send, copy=copy)
647 647
648 648 if self.debug:
649 649 pprint.pprint(msg)
650 650 pprint.pprint(to_send)
651 651 pprint.pprint(buffers)
652 652
653 653 msg['tracker'] = tracker
654 654
655 655 return msg
656 656
657 657 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
658 658 """Send a raw message via ident path.
659 659
660 660 This method is used to send a already serialized message.
661 661
662 662 Parameters
663 663 ----------
664 664 stream : ZMQStream or Socket
665 665 The ZMQ stream or socket to use for sending the message.
666 666 msg_list : list
667 667 The serialized list of messages to send. This only includes the
668 668 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
669 669 the message.
670 670 ident : ident or list
671 671 A single ident or a list of idents to use in sending.
672 672 """
673 673 to_send = []
674 674 if isinstance(ident, bytes):
675 675 ident = [ident]
676 676 if ident is not None:
677 677 to_send.extend(ident)
678 678
679 679 to_send.append(DELIM)
680 680 to_send.append(self.sign(msg_list))
681 681 to_send.extend(msg_list)
682 stream.send_multipart(msg_list, flags, copy=copy)
682 stream.send_multipart(to_send, flags, copy=copy)
683 683
684 684 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
685 685 """Receive and unpack a message.
686 686
687 687 Parameters
688 688 ----------
689 689 socket : ZMQStream or Socket
690 690 The socket or stream to use in receiving.
691 691
692 692 Returns
693 693 -------
694 694 [idents], msg
695 695 [idents] is a list of idents and msg is a nested message dict of
696 696 same format as self.msg returns.
697 697 """
698 698 if isinstance(socket, ZMQStream):
699 699 socket = socket.socket
700 700 try:
701 701 msg_list = socket.recv_multipart(mode, copy=copy)
702 702 except zmq.ZMQError as e:
703 703 if e.errno == zmq.EAGAIN:
704 704 # We can convert EAGAIN to None as we know in this case
705 705 # recv_multipart won't return None.
706 706 return None,None
707 707 else:
708 708 raise
709 709 # split multipart message into identity list and message dict
710 710 # invalid large messages can cause very expensive string comparisons
711 711 idents, msg_list = self.feed_identities(msg_list, copy)
712 712 try:
713 713 return idents, self.unserialize(msg_list, content=content, copy=copy)
714 714 except Exception as e:
715 715 # TODO: handle it
716 716 raise e
717 717
718 718 def feed_identities(self, msg_list, copy=True):
719 719 """Split the identities from the rest of the message.
720 720
721 721 Feed until DELIM is reached, then return the prefix as idents and
722 722 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
723 723 but that would be silly.
724 724
725 725 Parameters
726 726 ----------
727 727 msg_list : a list of Message or bytes objects
728 728 The message to be split.
729 729 copy : bool
730 730 flag determining whether the arguments are bytes or Messages
731 731
732 732 Returns
733 733 -------
734 734 (idents, msg_list) : two lists
735 735 idents will always be a list of bytes, each of which is a ZMQ
736 736 identity. msg_list will be a list of bytes or zmq.Messages of the
737 737 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
738 738 should be unpackable/unserializable via self.unserialize at this
739 739 point.
740 740 """
741 741 if copy:
742 742 idx = msg_list.index(DELIM)
743 743 return msg_list[:idx], msg_list[idx+1:]
744 744 else:
745 745 failed = True
746 746 for idx,m in enumerate(msg_list):
747 747 if m.bytes == DELIM:
748 748 failed = False
749 749 break
750 750 if failed:
751 751 raise ValueError("DELIM not in msg_list")
752 752 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
753 753 return [m.bytes for m in idents], msg_list
754 754
755 755 def _add_digest(self, signature):
756 756 """add a digest to history to protect against replay attacks"""
757 757 if self.digest_history_size == 0:
758 758 # no history, never add digests
759 759 return
760 760
761 761 self.digest_history.add(signature)
762 762 if len(self.digest_history) > self.digest_history_size:
763 763 # threshold reached, cull 10%
764 764 self._cull_digest_history()
765 765
766 766 def _cull_digest_history(self):
767 767 """cull the digest history
768 768
769 769 Removes a randomly selected 10% of the digest history
770 770 """
771 771 current = len(self.digest_history)
772 772 n_to_cull = max(int(current // 10), current - self.digest_history_size)
773 773 if n_to_cull >= current:
774 774 self.digest_history = set()
775 775 return
776 776 to_cull = random.sample(self.digest_history, n_to_cull)
777 777 self.digest_history.difference_update(to_cull)
778 778
779 779 def unserialize(self, msg_list, content=True, copy=True):
780 780 """Unserialize a msg_list to a nested message dict.
781 781
782 782 This is roughly the inverse of serialize. The serialize/unserialize
783 783 methods work with full message lists, whereas pack/unpack work with
784 784 the individual message parts in the message list.
785 785
786 786 Parameters
787 787 ----------
788 788 msg_list : list of bytes or Message objects
789 789 The list of message parts of the form [HMAC,p_header,p_parent,
790 790 p_metadata,p_content,buffer1,buffer2,...].
791 791 content : bool (True)
792 792 Whether to unpack the content dict (True), or leave it packed
793 793 (False).
794 794 copy : bool (True)
795 795 Whether to return the bytes (True), or the non-copying Message
796 796 object in each place (False).
797 797
798 798 Returns
799 799 -------
800 800 msg : dict
801 801 The nested message dict with top-level keys [header, parent_header,
802 802 content, buffers].
803 803 """
804 804 minlen = 5
805 805 message = {}
806 806 if not copy:
807 807 for i in range(minlen):
808 808 msg_list[i] = msg_list[i].bytes
809 809 if self.auth is not None:
810 810 signature = msg_list[0]
811 811 if not signature:
812 812 raise ValueError("Unsigned Message")
813 813 if signature in self.digest_history:
814 814 raise ValueError("Duplicate Signature: %r" % signature)
815 815 self._add_digest(signature)
816 816 check = self.sign(msg_list[1:5])
817 817 if not signature == check:
818 818 raise ValueError("Invalid Signature: %r" % signature)
819 819 if not len(msg_list) >= minlen:
820 820 raise TypeError("malformed message, must have at least %i elements"%minlen)
821 821 header = self.unpack(msg_list[1])
822 822 message['header'] = extract_dates(header)
823 823 message['msg_id'] = header['msg_id']
824 824 message['msg_type'] = header['msg_type']
825 825 message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
826 826 message['metadata'] = self.unpack(msg_list[3])
827 827 if content:
828 828 message['content'] = self.unpack(msg_list[4])
829 829 else:
830 830 message['content'] = msg_list[4]
831 831
832 832 message['buffers'] = msg_list[5:]
833 833 return message
834 834
835 835 def test_msg2obj():
836 836 am = dict(x=1)
837 837 ao = Message(am)
838 838 assert ao.x == am['x']
839 839
840 840 am['y'] = dict(z=1)
841 841 ao = Message(am)
842 842 assert ao.y.z == am['y']['z']
843 843
844 844 k1, k2 = 'y', 'z'
845 845 assert ao[k1][k2] == am[k1][k2]
846 846
847 847 am2 = dict(ao)
848 848 assert am['x'] == am2['x']
849 849 assert am['y']['z'] == am2['y']['z']
850 850
@@ -1,289 +1,313 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 from datetime import datetime
17 17
18 18 import zmq
19 19
20 20 from zmq.tests import BaseZMQTestCase
21 21 from zmq.eventloop.zmqstream import ZMQStream
22 22
23 23 from IPython.kernel.zmq import session as ss
24 24
25 25 from IPython.testing.decorators import skipif, module_not_available
26 26 from IPython.utils.py3compat import string_types
27 27 from IPython.utils import jsonutil
28 28
29 29 def _bad_packer(obj):
30 30 raise TypeError("I don't work")
31 31
32 32 def _bad_unpacker(bytes):
33 33 raise TypeError("I don't work either")
34 34
35 35 class SessionTestCase(BaseZMQTestCase):
36 36
37 37 def setUp(self):
38 38 BaseZMQTestCase.setUp(self)
39 39 self.session = ss.Session()
40 40
41 41
42 42 class TestSession(SessionTestCase):
43 43
44 44 def test_msg(self):
45 45 """message format"""
46 46 msg = self.session.msg('execute')
47 47 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
48 48 s = set(msg.keys())
49 49 self.assertEqual(s, thekeys)
50 50 self.assertTrue(isinstance(msg['content'],dict))
51 51 self.assertTrue(isinstance(msg['metadata'],dict))
52 52 self.assertTrue(isinstance(msg['header'],dict))
53 53 self.assertTrue(isinstance(msg['parent_header'],dict))
54 54 self.assertTrue(isinstance(msg['msg_id'],str))
55 55 self.assertTrue(isinstance(msg['msg_type'],str))
56 56 self.assertEqual(msg['header']['msg_type'], 'execute')
57 57 self.assertEqual(msg['msg_type'], 'execute')
58 58
59 59 def test_serialize(self):
60 60 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
61 61 msg_list = self.session.serialize(msg, ident=b'foo')
62 62 ident, msg_list = self.session.feed_identities(msg_list)
63 63 new_msg = self.session.unserialize(msg_list)
64 64 self.assertEqual(ident[0], b'foo')
65 65 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
66 66 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
67 67 self.assertEqual(new_msg['header'],msg['header'])
68 68 self.assertEqual(new_msg['content'],msg['content'])
69 69 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
70 70 self.assertEqual(new_msg['metadata'],msg['metadata'])
71 71 # ensure floats don't come out as Decimal:
72 72 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
73 73
74 74 def test_send(self):
75 75 ctx = zmq.Context.instance()
76 76 A = ctx.socket(zmq.PAIR)
77 77 B = ctx.socket(zmq.PAIR)
78 78 A.bind("inproc://test")
79 79 B.connect("inproc://test")
80 80
81 81 msg = self.session.msg('execute', content=dict(a=10))
82 82 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
83 83
84 84 ident, msg_list = self.session.feed_identities(B.recv_multipart())
85 85 new_msg = self.session.unserialize(msg_list)
86 86 self.assertEqual(ident[0], b'foo')
87 87 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
88 88 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
89 89 self.assertEqual(new_msg['header'],msg['header'])
90 90 self.assertEqual(new_msg['content'],msg['content'])
91 91 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
92 92 self.assertEqual(new_msg['metadata'],msg['metadata'])
93 93 self.assertEqual(new_msg['buffers'],[b'bar'])
94 94
95 95 content = msg['content']
96 96 header = msg['header']
97 97 parent = msg['parent_header']
98 98 metadata = msg['metadata']
99 99 msg_type = header['msg_type']
100 100 self.session.send(A, None, content=content, parent=parent,
101 101 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
102 102 ident, msg_list = self.session.feed_identities(B.recv_multipart())
103 103 new_msg = self.session.unserialize(msg_list)
104 104 self.assertEqual(ident[0], b'foo')
105 105 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
106 106 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
107 107 self.assertEqual(new_msg['header'],msg['header'])
108 108 self.assertEqual(new_msg['content'],msg['content'])
109 109 self.assertEqual(new_msg['metadata'],msg['metadata'])
110 110 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
111 111 self.assertEqual(new_msg['buffers'],[b'bar'])
112 112
113 113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
114 114 ident, new_msg = self.session.recv(B)
115 115 self.assertEqual(ident[0], b'foo')
116 116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
117 117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
118 118 self.assertEqual(new_msg['header'],msg['header'])
119 119 self.assertEqual(new_msg['content'],msg['content'])
120 120 self.assertEqual(new_msg['metadata'],msg['metadata'])
121 121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
122 122 self.assertEqual(new_msg['buffers'],[b'bar'])
123 123
124 124 A.close()
125 125 B.close()
126 126 ctx.term()
127 127
128 128 def test_args(self):
129 129 """initialization arguments for Session"""
130 130 s = self.session
131 131 self.assertTrue(s.pack is ss.default_packer)
132 132 self.assertTrue(s.unpack is ss.default_unpacker)
133 133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
134 134
135 135 s = ss.Session()
136 136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
137 137
138 138 self.assertRaises(TypeError, ss.Session, pack='hi')
139 139 self.assertRaises(TypeError, ss.Session, unpack='hi')
140 140 u = str(uuid.uuid4())
141 141 s = ss.Session(username=u'carrot', session=u)
142 142 self.assertEqual(s.session, u)
143 143 self.assertEqual(s.username, u'carrot')
144 144
145 145 def test_tracking(self):
146 146 """test tracking messages"""
147 147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
148 148 s = self.session
149 149 s.copy_threshold = 1
150 150 stream = ZMQStream(a)
151 151 msg = s.send(a, 'hello', track=False)
152 152 self.assertTrue(msg['tracker'] is ss.DONE)
153 153 msg = s.send(a, 'hello', track=True)
154 154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
155 155 M = zmq.Message(b'hi there', track=True)
156 156 msg = s.send(a, 'hello', buffers=[M], track=True)
157 157 t = msg['tracker']
158 158 self.assertTrue(isinstance(t, zmq.MessageTracker))
159 159 self.assertRaises(zmq.NotDone, t.wait, .1)
160 160 del M
161 161 t.wait(1) # this will raise
162 162
163 163
164 164 def test_unique_msg_ids(self):
165 165 """test that messages receive unique ids"""
166 166 ids = set()
167 167 for i in range(2**12):
168 168 h = self.session.msg_header('test')
169 169 msg_id = h['msg_id']
170 170 self.assertTrue(msg_id not in ids)
171 171 ids.add(msg_id)
172 172
173 173 def test_feed_identities(self):
174 174 """scrub the front for zmq IDENTITIES"""
175 175 theids = "engine client other".split()
176 176 content = dict(code='whoda',stuff=object())
177 177 themsg = self.session.msg('execute',content=content)
178 178 pmsg = theids
179 179
180 180 def test_session_id(self):
181 181 session = ss.Session()
182 182 # get bs before us
183 183 bs = session.bsession
184 184 us = session.session
185 185 self.assertEqual(us.encode('ascii'), bs)
186 186 session = ss.Session()
187 187 # get us before bs
188 188 us = session.session
189 189 bs = session.bsession
190 190 self.assertEqual(us.encode('ascii'), bs)
191 191 # change propagates:
192 192 session.session = 'something else'
193 193 bs = session.bsession
194 194 us = session.session
195 195 self.assertEqual(us.encode('ascii'), bs)
196 196 session = ss.Session(session='stuff')
197 197 # get us before bs
198 198 self.assertEqual(session.bsession, session.session.encode('ascii'))
199 199 self.assertEqual(b'stuff', session.bsession)
200 200
201 201 def test_zero_digest_history(self):
202 202 session = ss.Session(digest_history_size=0)
203 203 for i in range(11):
204 204 session._add_digest(uuid.uuid4().bytes)
205 205 self.assertEqual(len(session.digest_history), 0)
206 206
207 207 def test_cull_digest_history(self):
208 208 session = ss.Session(digest_history_size=100)
209 209 for i in range(100):
210 210 session._add_digest(uuid.uuid4().bytes)
211 211 self.assertTrue(len(session.digest_history) == 100)
212 212 session._add_digest(uuid.uuid4().bytes)
213 213 self.assertTrue(len(session.digest_history) == 91)
214 214 for i in range(9):
215 215 session._add_digest(uuid.uuid4().bytes)
216 216 self.assertTrue(len(session.digest_history) == 100)
217 217 session._add_digest(uuid.uuid4().bytes)
218 218 self.assertTrue(len(session.digest_history) == 91)
219 219
220 220 def test_bad_pack(self):
221 221 try:
222 222 session = ss.Session(pack=_bad_packer)
223 223 except ValueError as e:
224 224 self.assertIn("could not serialize", str(e))
225 225 self.assertIn("don't work", str(e))
226 226 else:
227 227 self.fail("Should have raised ValueError")
228 228
229 229 def test_bad_unpack(self):
230 230 try:
231 231 session = ss.Session(unpack=_bad_unpacker)
232 232 except ValueError as e:
233 233 self.assertIn("could not handle output", str(e))
234 234 self.assertIn("don't work either", str(e))
235 235 else:
236 236 self.fail("Should have raised ValueError")
237 237
238 238 def test_bad_packer(self):
239 239 try:
240 240 session = ss.Session(packer=__name__ + '._bad_packer')
241 241 except ValueError as e:
242 242 self.assertIn("could not serialize", str(e))
243 243 self.assertIn("don't work", str(e))
244 244 else:
245 245 self.fail("Should have raised ValueError")
246 246
247 247 def test_bad_unpacker(self):
248 248 try:
249 249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
250 250 except ValueError as e:
251 251 self.assertIn("could not handle output", str(e))
252 252 self.assertIn("don't work either", str(e))
253 253 else:
254 254 self.fail("Should have raised ValueError")
255 255
256 256 def test_bad_roundtrip(self):
257 257 with self.assertRaises(ValueError):
258 258 session = ss.Session(unpack=lambda b: 5)
259 259
260 260 def _datetime_test(self, session):
261 261 content = dict(t=datetime.now())
262 262 metadata = dict(t=datetime.now())
263 263 p = session.msg('msg')
264 264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
265 265 smsg = session.serialize(msg)
266 266 msg2 = session.unserialize(session.feed_identities(smsg)[1])
267 267 assert isinstance(msg2['header']['date'], datetime)
268 268 self.assertEqual(msg['header'], msg2['header'])
269 269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
270 270 self.assertEqual(msg['parent_header'], msg2['parent_header'])
271 271 assert isinstance(msg['content']['t'], datetime)
272 272 assert isinstance(msg['metadata']['t'], datetime)
273 273 assert isinstance(msg2['content']['t'], string_types)
274 274 assert isinstance(msg2['metadata']['t'], string_types)
275 275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
276 276 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
277 277
278 278 def test_datetimes(self):
279 279 self._datetime_test(self.session)
280 280
281 281 def test_datetimes_pickle(self):
282 282 session = ss.Session(packer='pickle')
283 283 self._datetime_test(session)
284 284
285 285 @skipif(module_not_available('msgpack'))
286 286 def test_datetimes_msgpack(self):
287 287 session = ss.Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
288 288 self._datetime_test(session)
289 289
290 def test_send_raw(self):
291 ctx = zmq.Context.instance()
292 A = ctx.socket(zmq.PAIR)
293 B = ctx.socket(zmq.PAIR)
294 A.bind("inproc://test")
295 B.connect("inproc://test")
296
297 msg = self.session.msg('execute', content=dict(a=10))
298 msg_list = [self.session.pack(msg[part]) for part in
299 ['header', 'parent_header', 'metadata', 'content']]
300 self.session.send_raw(A, msg_list, ident=b'foo')
301
302 ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
303 new_msg = self.session.unserialize(new_msg_list)
304 self.assertEqual(ident[0], b'foo')
305 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
306 self.assertEqual(new_msg['header'],msg['header'])
307 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
308 self.assertEqual(new_msg['content'],msg['content'])
309 self.assertEqual(new_msg['metadata'],msg['metadata'])
310
311 A.close()
312 B.close()
313 ctx.term()
General Comments 0
You need to be logged in to leave comments. Login now