##// END OF EJS Templates
cull Session digest history...
MinRK -
Show More
@@ -1,773 +1,806 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hmac
28 28 import logging
29 29 import os
30 30 import pprint
31 import random
31 32 import uuid
32 33 from datetime import datetime
33 34
34 35 try:
35 36 import cPickle
36 37 pickle = cPickle
37 38 except:
38 39 cPickle = None
39 40 import pickle
40 41
41 42 import zmq
42 43 from zmq.utils import jsonapi
43 44 from zmq.eventloop.ioloop import IOLoop
44 45 from zmq.eventloop.zmqstream import ZMQStream
45 46
46 47 from IPython.config.application import Application, boolean_flag
47 48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 49 from IPython.utils import io
49 50 from IPython.utils.importstring import import_item
50 51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 52 from IPython.utils.py3compat import str_to_bytes
52 53 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 54 DottedObjectName, CUnicode, Dict, Integer)
54 55 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
55 56
56 57 #-----------------------------------------------------------------------------
57 58 # utility functions
58 59 #-----------------------------------------------------------------------------
59 60
def squash_unicode(obj):
    """Recursively turn unicode strings back into utf8 bytestrings.

    Dicts (keys and values) and lists are converted in place and returned;
    a unicode string is returned encoded; any other object is returned
    untouched.
    """
    if isinstance(obj, dict):
        for name in obj.keys():
            obj[name] = squash_unicode(obj[name])
            if isinstance(name, unicode):
                # re-insert the value under the encoded key
                obj[squash_unicode(name)] = obj.pop(name)
        return obj
    if isinstance(obj, list):
        for index, item in enumerate(obj):
            obj[index] = squash_unicode(item)
        return obj
    if isinstance(obj, unicode):
        return obj.encode('utf8')
    return obj
73 74
74 75 #-----------------------------------------------------------------------------
75 76 # globals and defaults
76 77 #-----------------------------------------------------------------------------
77 78
# ISO8601-ify datetime objects
# json_packer passes `date_default` as the JSON fallback serializer;
# json_unpacker runs `extract_dates` over the decoded structure.
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
json_unpacker = lambda s: extract_dates(jsonapi.loads(s))

# protocol -1 selects the highest pickle protocol available
pickle_packer = lambda o: pickle.dumps(o,-1)
pickle_unpacker = pickle.loads

# JSON is the serialization used when nothing else is configured
default_packer = json_packer
default_unpacker = json_unpacker

# frame separating the routing identities from the signed message parts
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()
91 92
92 93 #-----------------------------------------------------------------------------
93 94 # Mixin tools for apps that use Sessions
94 95 #-----------------------------------------------------------------------------
95 96
# command-line aliases mapping short option names onto Session traits
session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)

# command-line flags, each a (config dict, help text) pair.
# NOTE: the UUID used for 'secure' is generated once, when this module is
# imported, not each time the flag is used.
session_flags  = {
    'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
111 112
def default_secure(cfg):
    """Make a config environment secure by default.

    When neither Session.key nor Session.keyfile is present in `cfg`,
    Session.key is set to a freshly generated random UUID; otherwise `cfg`
    is left alone.
    """
    already_configured = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if already_configured:
        return
    # key/keyfile not specified, generate new UUID:
    new_key = str(uuid.uuid4())
    cfg.Session.key = str_to_bytes(new_key)
124 125
125 126
126 127 #-----------------------------------------------------------------------------
127 128 # Classes
128 129 #-----------------------------------------------------------------------------
129 130
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.

    If no `session` is passed in, one is built in __init__ from the same
    keyword arguments.
    """

    logname = Unicode('')
    def _logname_changed(self, name, old, new):
        # re-point self.log at the logger with the new name
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # default: zmq.Context.instance()
        return zmq.Context.instance()

    session = Instance('IPython.kernel.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
    def _loop_default(self):
        # default: IOLoop.instance()
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
156 157
157 158
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj).  Nested dicts are wrapped in Message
    objects as well, so ``msg.header.msg_id`` works.
    """

    def __init__(self, msg_dict):
        dct = self.__dict__
        # items() rather than iteritems(): works on Python 2 and Python 3
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        # raises KeyError for unknown keys, like a dict
        return self.__dict__[k]
186 187
187 188
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict.

    Returns a dict with the keys msg_id, msg_type, username and session
    (the arguments, unchanged) plus `date`, the current local datetime.
    """
    # Spell the keys out instead of returning locals(), so that a future
    # local variable cannot silently leak into every header.
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
    }
191 192
def extract_header(msg_or_header):
    """Return the header of a message, or the header itself if given one.

    A falsy argument yields an empty dict.  An object with neither a
    'header' nor a 'msg_id' key raises KeyError.  The result is always a
    dict; non-dict headers are converted with dict().
    """
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        header = msg_or_header['header']
    except KeyError:
        # Otherwise it must be just the header: this lookup raises
        # KeyError if it is not.
        msg_or_header['msg_id']
        header = msg_or_header
    return header if isinstance(header, dict) else dict(header)
210 211
class Session(Configurable):
    """Object for handling serialization and sending of messages.

    The Session object handles building messages and sending them
    with ZMQ sockets or ZMQStream objects.  Objects can communicate with each
    other over the network via Session objects, and only need to work with the
    dict-based IPython message spec. The Session will handle
    serialization/deserialization, security, and metadata.

    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts.  If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization directly.
    session : unicode (must be ascii)
        the ID of this Session object.  The default is to generate a new UUID.
    username : unicode
        username added to message headers.  The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature.  If unset, messages
        will not be signed or checked.
    keyfile : filepath
        The file containing a key.  If this is set, `key` will be initialized
        to the contents of the file.

    """

    debug=Bool(False, config=True, help="""Debug output in the Session""")
253 254
254 255 packer = DottedObjectName('json',config=True,
255 256 help="""The name of the packer for serializing messages.
256 257 Should be one of 'json', 'pickle', or an import name
257 258 for a custom callable serializer.""")
258 259 def _packer_changed(self, name, old, new):
259 260 if new.lower() == 'json':
260 261 self.pack = json_packer
261 262 self.unpack = json_unpacker
262 263 self.unpacker = new
263 264 elif new.lower() == 'pickle':
264 265 self.pack = pickle_packer
265 266 self.unpack = pickle_unpacker
266 267 self.unpacker = new
267 268 else:
268 269 self.pack = import_item(str(new))
269 270
270 271 unpacker = DottedObjectName('json', config=True,
271 272 help="""The name of the unpacker for unserializing messages.
272 273 Only used with custom functions for `packer`.""")
273 274 def _unpacker_changed(self, name, old, new):
274 275 if new.lower() == 'json':
275 276 self.pack = json_packer
276 277 self.unpack = json_unpacker
277 278 self.packer = new
278 279 elif new.lower() == 'pickle':
279 280 self.pack = pickle_packer
280 281 self.unpack = pickle_unpacker
281 282 self.packer = new
282 283 else:
283 284 self.unpack = import_item(str(new))
284 285
285 286 session = CUnicode(u'', config=True,
286 287 help="""The UUID identifying this session.""")
287 288 def _session_default(self):
288 289 u = unicode(uuid.uuid4())
289 290 self.bsession = u.encode('ascii')
290 291 return u
291 292
292 293 def _session_changed(self, name, old, new):
293 294 self.bsession = self.session.encode('ascii')
294 295
295 296 # bsession is the session as bytes
296 297 bsession = CBytes(b'')
297 298
    # taken from $USER; falls back to the literal u'username' when unset
    username = Unicode(os.environ.get('USER',u'username'), config=True,
        help="""Username for the Session. Default is your system username.""")

    # msg() copies this dict into every message it builds, then merges any
    # per-message metadata on top
    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
303 304
304 305 # message signature related traits:
305 306
306 307 key = CBytes(b'', config=True,
307 308 help="""execution key, for extra authentication.""")
308 309 def _key_changed(self, name, old, new):
309 310 if new:
310 311 self.auth = hmac.HMAC(new)
311 312 else:
312 313 self.auth = None
314
313 315 auth = Instance(hmac.HMAC)
316
314 317 digest_history = Set()
318 digest_history_size = Integer(2**16, config=True,
319 help="""The maximum number of digests to remember.
320
321 The digest history will be culled when it exceeds this value.
322 """
323 )
315 324
316 325 keyfile = Unicode('', config=True,
317 326 help="""path to file containing execution key.""")
318 327 def _keyfile_changed(self, name, old, new):
319 328 with open(new, 'rb') as f:
320 329 self.key = f.read().strip()
321 330
322 331 # for protecting against sends from forks
323 332 pid = Integer()
324 333
325 334 # serialization traits:
326 335
327 336 pack = Any(default_packer) # the actual packer function
328 337 def _pack_changed(self, name, old, new):
329 338 if not callable(new):
330 339 raise TypeError("packer must be callable, not %s"%type(new))
331 340
332 341 unpack = Any(default_unpacker) # the actual packer function
333 342 def _unpack_changed(self, name, old, new):
334 343 # unpacker is not checked - it is assumed to be
335 344 if not callable(new):
336 345 raise TypeError("unpacker must be callable, not %s"%type(new))
337 346
    # thresholds:
    # send() compares the longest outgoing frame against copy_threshold to
    # decide whether to send with copy=True or copy=False
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    # NOTE(review): buffer_threshold and item_threshold are not read in the
    # visible part of this file -- presumably used by the serialization
    # helpers that MAX_BYTES / MAX_ITEMS come from; confirm.
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )
348 357
349 358
350 359 def __init__(self, **kwargs):
351 360 """create a Session object
352 361
353 362 Parameters
354 363 ----------
355 364
356 365 debug : bool
357 366 whether to trigger extra debugging statements
358 367 packer/unpacker : str : 'json', 'pickle' or import_string
359 368 importstrings for methods to serialize message parts. If just
360 369 'json' or 'pickle', predefined JSON and pickle packers will be used.
361 370 Otherwise, the entire importstring must be used.
362 371
363 372 The functions must accept at least valid JSON input, and output
364 373 *bytes*.
365 374
366 375 For example, to use msgpack:
367 376 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
368 377 pack/unpack : callables
369 378 You can also set the pack/unpack callables for serialization
370 379 directly.
371 380 session : unicode (must be ascii)
372 381 the ID of this Session object. The default is to generate a new
373 382 UUID.
374 383 bsession : bytes
375 384 The session as bytes
376 385 username : unicode
377 386 username added to message headers. The default is to ask the OS.
378 387 key : bytes
379 388 The key used to initialize an HMAC signature. If unset, messages
380 389 will not be signed or checked.
381 390 keyfile : filepath
382 391 The file containing a key. If this is set, `key` will be
383 392 initialized to the contents of the file.
384 393 """
385 394 super(Session, self).__init__(**kwargs)
386 395 self._check_packers()
387 396 self.none = self.pack({})
388 397 # ensure self._session_default() if necessary, so bsession is defined:
389 398 self.session
390 399 self.pid = os.getpid()
391 400
392 401 @property
393 402 def msg_id(self):
394 403 """always return new uuid"""
395 404 return str(uuid.uuid4())
396 405
397 406 def _check_packers(self):
398 407 """check packers for binary data and datetime support."""
399 408 pack = self.pack
400 409 unpack = self.unpack
401 410
402 411 # check simple serialization
403 412 msg = dict(a=[1,'hi'])
404 413 try:
405 414 packed = pack(msg)
406 415 except Exception:
407 416 raise ValueError("packer could not serialize a simple message")
408 417
409 418 # ensure packed message is bytes
410 419 if not isinstance(packed, bytes):
411 420 raise ValueError("message packed to %r, but bytes are required"%type(packed))
412 421
413 422 # check that unpack is pack's inverse
414 423 try:
415 424 unpacked = unpack(packed)
416 425 except Exception:
417 426 raise ValueError("unpacker could not handle the packer's output")
418 427
419 428 # check datetime support
420 429 msg = dict(t=datetime.now())
421 430 try:
422 431 unpacked = unpack(pack(msg))
423 432 except Exception:
424 433 self.pack = lambda o: pack(squash_dates(o))
425 434 self.unpack = lambda s: extract_dates(unpack(s))
426 435
427 436 def msg_header(self, msg_type):
428 437 return msg_header(self.msg_id, msg_type, self.username, self.session)
429 438
430 439 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
431 440 """Return the nested message dict.
432 441
433 442 This format is different from what is sent over the wire. The
434 443 serialize/unserialize methods converts this nested message dict to the wire
435 444 format, which is a list of message parts.
436 445 """
437 446 msg = {}
438 447 header = self.msg_header(msg_type) if header is None else header
439 448 msg['header'] = header
440 449 msg['msg_id'] = header['msg_id']
441 450 msg['msg_type'] = header['msg_type']
442 451 msg['parent_header'] = {} if parent is None else extract_header(parent)
443 452 msg['content'] = {} if content is None else content
444 453 msg['metadata'] = self.metadata.copy()
445 454 if metadata is not None:
446 455 msg['metadata'].update(metadata)
447 456 return msg
448 457
449 458 def sign(self, msg_list):
450 459 """Sign a message with HMAC digest. If no auth, return b''.
451 460
452 461 Parameters
453 462 ----------
454 463 msg_list : list
455 464 The [p_header,p_parent,p_content] part of the message list.
456 465 """
457 466 if self.auth is None:
458 467 return b''
459 468 h = self.auth.copy()
460 469 for m in msg_list:
461 470 h.update(m)
462 471 return str_to_bytes(h.hexdigest())
463 472
464 473 def serialize(self, msg, ident=None):
465 474 """Serialize the message components to bytes.
466 475
467 476 This is roughly the inverse of unserialize. The serialize/unserialize
468 477 methods work with full message lists, whereas pack/unpack work with
469 478 the individual message parts in the message list.
470 479
471 480 Parameters
472 481 ----------
473 482 msg : dict or Message
474 483 The nexted message dict as returned by the self.msg method.
475 484
476 485 Returns
477 486 -------
478 487 msg_list : list
479 488 The list of bytes objects to be sent with the format:
480 489 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
481 490 buffer1,buffer2,...]. In this list, the p_* entities are
482 491 the packed or serialized versions, so if JSON is used, these
483 492 are utf8 encoded JSON strings.
484 493 """
485 494 content = msg.get('content', {})
486 495 if content is None:
487 496 content = self.none
488 497 elif isinstance(content, dict):
489 498 content = self.pack(content)
490 499 elif isinstance(content, bytes):
491 500 # content is already packed, as in a relayed message
492 501 pass
493 502 elif isinstance(content, unicode):
494 503 # should be bytes, but JSON often spits out unicode
495 504 content = content.encode('utf8')
496 505 else:
497 506 raise TypeError("Content incorrect type: %s"%type(content))
498 507
499 508 real_message = [self.pack(msg['header']),
500 509 self.pack(msg['parent_header']),
501 510 self.pack(msg['metadata']),
502 511 content,
503 512 ]
504 513
505 514 to_send = []
506 515
507 516 if isinstance(ident, list):
508 517 # accept list of idents
509 518 to_send.extend(ident)
510 519 elif ident is not None:
511 520 to_send.append(ident)
512 521 to_send.append(DELIM)
513 522
514 523 signature = self.sign(real_message)
515 524 to_send.append(signature)
516 525
517 526 to_send.extend(real_message)
518 527
519 528 return to_send
520 529
521 530 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
522 531 buffers=None, track=False, header=None, metadata=None):
523 532 """Build and send a message via stream or socket.
524 533
525 534 The message format used by this function internally is as follows:
526 535
527 536 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
528 537 buffer1,buffer2,...]
529 538
530 539 The serialize/unserialize methods convert the nested message dict into this
531 540 format.
532 541
533 542 Parameters
534 543 ----------
535 544
536 545 stream : zmq.Socket or ZMQStream
537 546 The socket-like object used to send the data.
538 547 msg_or_type : str or Message/dict
539 548 Normally, msg_or_type will be a msg_type unless a message is being
540 549 sent more than once. If a header is supplied, this can be set to
541 550 None and the msg_type will be pulled from the header.
542 551
543 552 content : dict or None
544 553 The content of the message (ignored if msg_or_type is a message).
545 554 header : dict or None
546 555 The header dict for the message (ignored if msg_to_type is a message).
547 556 parent : Message or dict or None
548 557 The parent or parent header describing the parent of this message
549 558 (ignored if msg_or_type is a message).
550 559 ident : bytes or list of bytes
551 560 The zmq.IDENTITY routing path.
552 561 metadata : dict or None
553 562 The metadata describing the message
554 563 buffers : list or None
555 564 The already-serialized buffers to be appended to the message.
556 565 track : bool
557 566 Whether to track. Only for use with Sockets, because ZMQStream
558 567 objects cannot track messages.
559 568
560 569
561 570 Returns
562 571 -------
563 572 msg : dict
564 573 The constructed message.
565 574 """
566 575 if not isinstance(stream, zmq.Socket):
567 576 # ZMQStreams and dummy sockets do not support tracking.
568 577 track = False
569 578
570 579 if isinstance(msg_or_type, (Message, dict)):
571 580 # We got a Message or message dict, not a msg_type so don't
572 581 # build a new Message.
573 582 msg = msg_or_type
574 583 else:
575 584 msg = self.msg(msg_or_type, content=content, parent=parent,
576 585 header=header, metadata=metadata)
577 586 if not os.getpid() == self.pid:
578 587 io.rprint("WARNING: attempted to send message from fork")
579 588 io.rprint(msg)
580 589 return
581 590 buffers = [] if buffers is None else buffers
582 591 to_send = self.serialize(msg, ident)
583 592 to_send.extend(buffers)
584 593 longest = max([ len(s) for s in to_send ])
585 594 copy = (longest < self.copy_threshold)
586 595
587 596 if buffers and track and not copy:
588 597 # only really track when we are doing zero-copy buffers
589 598 tracker = stream.send_multipart(to_send, copy=False, track=True)
590 599 else:
591 600 # use dummy tracker, which will be done immediately
592 601 tracker = DONE
593 602 stream.send_multipart(to_send, copy=copy)
594 603
595 604 if self.debug:
596 605 pprint.pprint(msg)
597 606 pprint.pprint(to_send)
598 607 pprint.pprint(buffers)
599 608
600 609 msg['tracker'] = tracker
601 610
602 611 return msg
603 612
604 613 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
605 614 """Send a raw message via ident path.
606 615
607 616 This method is used to send a already serialized message.
608 617
609 618 Parameters
610 619 ----------
611 620 stream : ZMQStream or Socket
612 621 The ZMQ stream or socket to use for sending the message.
613 622 msg_list : list
614 623 The serialized list of messages to send. This only includes the
615 624 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
616 625 the message.
617 626 ident : ident or list
618 627 A single ident or a list of idents to use in sending.
619 628 """
620 629 to_send = []
621 630 if isinstance(ident, bytes):
622 631 ident = [ident]
623 632 if ident is not None:
624 633 to_send.extend(ident)
625 634
626 635 to_send.append(DELIM)
627 636 to_send.append(self.sign(msg_list))
628 637 to_send.extend(msg_list)
629 638 stream.send_multipart(msg_list, flags, copy=copy)
630 639
631 640 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
632 641 """Receive and unpack a message.
633 642
634 643 Parameters
635 644 ----------
636 645 socket : ZMQStream or Socket
637 646 The socket or stream to use in receiving.
638 647
639 648 Returns
640 649 -------
641 650 [idents], msg
642 651 [idents] is a list of idents and msg is a nested message dict of
643 652 same format as self.msg returns.
644 653 """
645 654 if isinstance(socket, ZMQStream):
646 655 socket = socket.socket
647 656 try:
648 657 msg_list = socket.recv_multipart(mode, copy=copy)
649 658 except zmq.ZMQError as e:
650 659 if e.errno == zmq.EAGAIN:
651 660 # We can convert EAGAIN to None as we know in this case
652 661 # recv_multipart won't return None.
653 662 return None,None
654 663 else:
655 664 raise
656 665 # split multipart message into identity list and message dict
657 666 # invalid large messages can cause very expensive string comparisons
658 667 idents, msg_list = self.feed_identities(msg_list, copy)
659 668 try:
660 669 return idents, self.unserialize(msg_list, content=content, copy=copy)
661 670 except Exception as e:
662 671 # TODO: handle it
663 672 raise e
664 673
665 674 def feed_identities(self, msg_list, copy=True):
666 675 """Split the identities from the rest of the message.
667 676
668 677 Feed until DELIM is reached, then return the prefix as idents and
669 678 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
670 679 but that would be silly.
671 680
672 681 Parameters
673 682 ----------
674 683 msg_list : a list of Message or bytes objects
675 684 The message to be split.
676 685 copy : bool
677 686 flag determining whether the arguments are bytes or Messages
678 687
679 688 Returns
680 689 -------
681 690 (idents, msg_list) : two lists
682 691 idents will always be a list of bytes, each of which is a ZMQ
683 692 identity. msg_list will be a list of bytes or zmq.Messages of the
684 693 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
685 694 should be unpackable/unserializable via self.unserialize at this
686 695 point.
687 696 """
688 697 if copy:
689 698 idx = msg_list.index(DELIM)
690 699 return msg_list[:idx], msg_list[idx+1:]
691 700 else:
692 701 failed = True
693 702 for idx,m in enumerate(msg_list):
694 703 if m.bytes == DELIM:
695 704 failed = False
696 705 break
697 706 if failed:
698 707 raise ValueError("DELIM not in msg_list")
699 708 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
700 709 return [m.bytes for m in idents], msg_list
701 710
711 def _add_digest(self, signature):
712 """add a digest to history to protect against replay attacks"""
713 if self.digest_history_size == 0:
714 # no history, never add digests
715 return
716
717 self.digest_history.add(signature)
718 if len(self.digest_history) > self.digest_history_size:
719 # threshold reached, cull 10%
720 self._cull_digest_history()
721
722 def _cull_digest_history(self):
723 """cull the digest history
724
725 Removes a randomly selected 10% of the digest history
726 """
727 current = len(self.digest_history)
728 n_to_cull = max(int(current // 10), current - self.digest_history_size)
729 if n_to_cull >= current:
730 self.digest_history = set()
731 return
732 to_cull = random.sample(self.digest_history, n_to_cull)
733 self.digest_history.difference_update(to_cull)
734
702 735 def unserialize(self, msg_list, content=True, copy=True):
703 736 """Unserialize a msg_list to a nested message dict.
704 737
705 738 This is roughly the inverse of serialize. The serialize/unserialize
706 739 methods work with full message lists, whereas pack/unpack work with
707 740 the individual message parts in the message list.
708 741
709 742 Parameters:
710 743 -----------
711 744 msg_list : list of bytes or Message objects
712 745 The list of message parts of the form [HMAC,p_header,p_parent,
713 746 p_metadata,p_content,buffer1,buffer2,...].
714 747 content : bool (True)
715 748 Whether to unpack the content dict (True), or leave it packed
716 749 (False).
717 750 copy : bool (True)
718 751 Whether to return the bytes (True), or the non-copying Message
719 752 object in each place (False).
720 753
721 754 Returns
722 755 -------
723 756 msg : dict
724 757 The nested message dict with top-level keys [header, parent_header,
725 758 content, buffers].
726 759 """
727 760 minlen = 5
728 761 message = {}
729 762 if not copy:
730 763 for i in range(minlen):
731 764 msg_list[i] = msg_list[i].bytes
732 765 if self.auth is not None:
733 766 signature = msg_list[0]
734 767 if not signature:
735 768 raise ValueError("Unsigned Message")
736 769 if signature in self.digest_history:
737 raise ValueError("Duplicate Signature: %r"%signature)
738 self.digest_history.add(signature)
770 raise ValueError("Duplicate Signature: %r" % signature)
771 self._add_digest(signature)
739 772 check = self.sign(msg_list[1:5])
740 773 if not signature == check:
741 774 raise ValueError("Invalid Signature: %r" % signature)
742 775 if not len(msg_list) >= minlen:
743 776 raise TypeError("malformed message, must have at least %i elements"%minlen)
744 777 header = self.unpack(msg_list[1])
745 778 message['header'] = header
746 779 message['msg_id'] = header['msg_id']
747 780 message['msg_type'] = header['msg_type']
748 781 message['parent_header'] = self.unpack(msg_list[2])
749 782 message['metadata'] = self.unpack(msg_list[3])
750 783 if content:
751 784 message['content'] = self.unpack(msg_list[4])
752 785 else:
753 786 message['content'] = msg_list[4]
754 787
755 788 message['buffers'] = msg_list[5:]
756 789 return message
757 790
def test_msg2obj():
    """Message gives attribute and item access, and round-trips through dict()."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts become nested Messages
    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    roundtrip = dict(wrapped)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
773 806
@@ -1,207 +1,225 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.kernel.zmq import session as ss
22 22
class SessionTestCase(BaseZMQTestCase):
    """Base test case that exposes an ``ss.Session`` as ``self.session``."""

    def setUp(self):
        """Run the zmq base-class setup, then create the session under test."""
        super(SessionTestCase, self).setUp()
        self.session = ss.Session()
28 28
29 29
class TestSession(SessionTestCase):
    """Tests for building, serializing, sending and receiving Session messages."""

    def _assert_same_message(self, new_msg, msg):
        """Assert that the fields of a received message match the one sent."""
        for key in ('msg_id', 'msg_type', 'header', 'content',
                    'parent_header', 'metadata'):
            self.assertEqual(new_msg[key], msg[key])

    def test_msg(self):
        """message format"""
        msg = self.session.msg('execute')
        thekeys = set('header parent_header metadata content msg_type msg_id'.split())
        s = set(msg.keys())
        self.assertEqual(s, thekeys)
        self.assertTrue(isinstance(msg['content'], dict))
        self.assertTrue(isinstance(msg['metadata'], dict))
        self.assertTrue(isinstance(msg['header'], dict))
        self.assertTrue(isinstance(msg['parent_header'], dict))
        self.assertTrue(isinstance(msg['msg_id'], str))
        self.assertTrue(isinstance(msg['msg_type'], str))
        self.assertEqual(msg['header']['msg_type'], 'execute')
        self.assertEqual(msg['msg_type'], 'execute')

    def test_serialize(self):
        """a serialized message survives feed_identities + unserialize"""
        msg = self.session.msg('execute', content=dict(a=10, b=1.1))
        msg_list = self.session.serialize(msg, ident=b'foo')
        ident, msg_list = self.session.feed_identities(msg_list)
        new_msg = self.session.unserialize(msg_list)
        self.assertEqual(ident[0], b'foo')
        self._assert_same_message(new_msg, msg)
        # ensure floats don't come out as Decimal: compare against the type
        # of the value that was sent, not against the received value itself.
        self.assertEqual(type(new_msg['content']['b']), type(msg['content']['b']))

    def test_send(self):
        """send over a PAIR socket and read the message back"""
        ctx = zmq.Context.instance()
        A = ctx.socket(zmq.PAIR)
        B = ctx.socket(zmq.PAIR)
        try:
            A.bind("inproc://test")
            B.connect("inproc://test")

            # 1. send a prebuilt message dict, unserialize by hand
            msg = self.session.msg('execute', content=dict(a=10))
            self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])

            ident, msg_list = self.session.feed_identities(B.recv_multipart())
            new_msg = self.session.unserialize(msg_list)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_message(new_msg, msg)
            self.assertEqual(new_msg['buffers'], [b'bar'])

            # 2. send the same message given as its individual parts
            content = msg['content']
            header = msg['header']
            parent = msg['parent_header']
            metadata = msg['metadata']
            self.session.send(A, None, content=content, parent=parent,
                header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
            ident, msg_list = self.session.feed_identities(B.recv_multipart())
            new_msg = self.session.unserialize(msg_list)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_message(new_msg, msg)
            self.assertEqual(new_msg['buffers'], [b'bar'])

            # 3. send the message dict again, receive via Session.recv
            self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
            ident, new_msg = self.session.recv(B)
            self.assertEqual(ident[0], b'foo')
            self._assert_same_message(new_msg, msg)
            self.assertEqual(new_msg['buffers'], [b'bar'])
        finally:
            # Close the sockets even when an assertion above fails, so a
            # failing test does not leak them.
            A.close()
            B.close()
        ctx.term()

    def test_args(self):
        """initialization arguments for Session"""
        s = self.session
        self.assertTrue(s.pack is ss.default_packer)
        self.assertTrue(s.unpack is ss.default_unpacker)
        self.assertEqual(s.username, os.environ.get('USER', u'username'))

        s = ss.Session()
        self.assertEqual(s.username, os.environ.get('USER', u'username'))

        self.assertRaises(TypeError, ss.Session, pack='hi')
        self.assertRaises(TypeError, ss.Session, unpack='hi')
        u = str(uuid.uuid4())
        s = ss.Session(username=u'carrot', session=u)
        self.assertEqual(s.session, u)
        self.assertEqual(s.username, u'carrot')

    def test_tracking(self):
        """test tracking messages"""
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        s = self.session
        s.copy_threshold = 1
        # NOTE(review): `stream` is never used below; kept in case building
        # the ZMQStream matters for this test -- confirm before removing.
        stream = ZMQStream(a)
        msg = s.send(a, 'hello', track=False)
        self.assertTrue(msg['tracker'] is ss.DONE)
        msg = s.send(a, 'hello', track=True)
        self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
        M = zmq.Message(b'hi there', track=True)
        msg = s.send(a, 'hello', buffers=[M], track=True)
        t = msg['tracker']
        self.assertTrue(isinstance(t, zmq.MessageTracker))
        # While M is still referenced, waiting on the tracker times out.
        self.assertRaises(zmq.NotDone, t.wait, .1)
        del M
        # With M dropped this wait is expected to return; the test fails
        # here if it raises instead.
        t.wait(1)

    # def test_rekey(self):
    #     """rekeying dict around json str keys"""
    #     d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
    #     self.assertRaises(KeyError, ss.rekey, d)
    #
    #     d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
    #     d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
    #     rd = ss.rekey(d)
    #     self.assertEqual(d2,rd)
    #
    #     d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
    #     d2 = {1.5:d['1.5'],1:d['1']}
    #     rd = ss.rekey(d)
    #     self.assertEqual(d2,rd)
    #
    #     d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
    #     self.assertRaises(KeyError, ss.rekey, d)
    #
    def test_unique_msg_ids(self):
        """test that messages receive unique ids"""
        ids = set()
        for i in range(2**12):
            h = self.session.msg_header('test')
            msg_id = h['msg_id']
            self.assertTrue(msg_id not in ids)
            ids.add(msg_id)

    def test_feed_identities(self):
        """scrub the front for zmq IDENTITIES"""
        # NOTE(review): this test only builds its inputs and asserts nothing;
        # it looks unfinished -- confirm the intended checks and complete it.
        theids = "engine client other".split()
        content = dict(code='whoda', stuff=object())
        themsg = self.session.msg('execute', content=content)
        pmsg = theids

    def test_session_id(self):
        """session and bsession stay in sync, whichever is read or set first"""
        session = ss.Session()
        # get bs before us
        bs = session.bsession
        us = session.session
        self.assertEqual(us.encode('ascii'), bs)
        session = ss.Session()
        # get us before bs
        us = session.session
        bs = session.bsession
        self.assertEqual(us.encode('ascii'), bs)
        # change propagates:
        session.session = 'something else'
        bs = session.bsession
        us = session.session
        self.assertEqual(us.encode('ascii'), bs)
        session = ss.Session(session='stuff')
        # get us before bs
        self.assertEqual(session.bsession, session.session.encode('ascii'))
        self.assertEqual(b'stuff', session.bsession)

    def test_zero_digest_history(self):
        """with digest_history_size=0 no digests are retained"""
        session = ss.Session(digest_history_size=0)
        for i in range(11):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 0)

    def test_cull_digest_history(self):
        """exceeding digest_history_size culls the history"""
        session = ss.Session(digest_history_size=100)
        for i in range(100):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 100)
        # One digest past the limit shrinks the history to 91 entries.
        session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 91)
        for i in range(9):
            session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 100)
        # Going past the limit a second time culls again.
        session._add_digest(uuid.uuid4().bytes)
        self.assertEqual(len(session.digest_history), 91)
207 225
General Comments 0
You need to be logged in to leave comments. Login now