##// END OF EJS Templates
Adding tests for zmq.session.
Brian E. Granger -
Show More
@@ -1,696 +1,696 b''
1 1 #!/usr/bin/env python
2 2 """Session object for building, serializing, sending, and receiving messages in
3 3 IPython. The Session object supports serialization, HMAC signatures, and
4 4 metadata on messages.
5 5
6 6 Also defined here are utilities for working with Sessions:
7 7 * A SessionFactory to be used as a base class for configurables that work with
8 8 Sessions.
9 9 * A Message object for convenience that allows attribute-access to the msg dict.
10 10
11 11 Authors:
12 12
13 13 * Min RK
14 14 * Brian Granger
15 15 * Fernando Perez
16 16 """
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2010-2011 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 import hmac
29 29 import logging
30 30 import os
31 31 import pprint
32 32 import uuid
33 33 from datetime import datetime
34 34
35 35 try:
36 36 import cPickle
37 37 pickle = cPickle
38 38 except:
39 39 cPickle = None
40 40 import pickle
41 41
42 42 import zmq
43 43 from zmq.utils import jsonapi
44 44 from zmq.eventloop.ioloop import IOLoop
45 45 from zmq.eventloop.zmqstream import ZMQStream
46 46
47 47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 48 from IPython.utils.importstring import import_item
49 49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 51 DottedObjectName)
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # utility functions
55 55 #-----------------------------------------------------------------------------
56 56
57 57 def squash_unicode(obj):
58 58 """coerce unicode back to bytestrings."""
59 59 if isinstance(obj,dict):
60 60 for key in obj.keys():
61 61 obj[key] = squash_unicode(obj[key])
62 62 if isinstance(key, unicode):
63 63 obj[squash_unicode(key)] = obj.pop(key)
64 64 elif isinstance(obj, list):
65 65 for i,v in enumerate(obj):
66 66 obj[i] = squash_unicode(v)
67 67 elif isinstance(obj, unicode):
68 68 obj = obj.encode('utf8')
69 69 return obj
70 70
71 71 #-----------------------------------------------------------------------------
72 72 # globals and defaults
73 73 #-----------------------------------------------------------------------------
74 74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77 77
78 78 pickle_packer = lambda o: pickle.dumps(o,-1)
79 79 pickle_unpacker = pickle.loads
80 80
81 81 default_packer = json_packer
82 82 default_unpacker = json_unpacker
83 83
84 84
85 85 DELIM=b"<IDS|MSG>"
86 86
87 87 #-----------------------------------------------------------------------------
88 88 # Classes
89 89 #-----------------------------------------------------------------------------
90 90
91 91 class SessionFactory(LoggingConfigurable):
92 92 """The Base class for configurables that have a Session, Context, logger,
93 93 and IOLoop.
94 94 """
95 95
96 96 logname = Unicode('')
97 97 def _logname_changed(self, name, old, new):
98 98 self.log = logging.getLogger(new)
99 99
100 100 # not configurable:
101 101 context = Instance('zmq.Context')
102 102 def _context_default(self):
103 103 return zmq.Context.instance()
104 104
105 105 session = Instance('IPython.zmq.session.Session')
106 106
107 107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 108 def _loop_default(self):
109 109 return IOLoop.instance()
110 110
111 111 def __init__(self, **kwargs):
112 112 super(SessionFactory, self).__init__(**kwargs)
113 113
114 114 if self.session is None:
115 115 # construct the session
116 116 self.session = Session(**kwargs)
117 117
118 118
119 119 class Message(object):
120 120 """A simple message object that maps dict keys to attributes.
121 121
122 122 A Message can be created from a dict and a dict from a Message instance
123 123 simply by calling dict(msg_obj)."""
124 124
125 125 def __init__(self, msg_dict):
126 126 dct = self.__dict__
127 127 for k, v in dict(msg_dict).iteritems():
128 128 if isinstance(v, dict):
129 129 v = Message(v)
130 130 dct[k] = v
131 131
132 132 # Having this iterator lets dict(msg_obj) work out of the box.
133 133 def __iter__(self):
134 134 return iter(self.__dict__.iteritems())
135 135
136 136 def __repr__(self):
137 137 return repr(self.__dict__)
138 138
139 139 def __str__(self):
140 140 return pprint.pformat(self.__dict__)
141 141
142 142 def __contains__(self, k):
143 143 return k in self.__dict__
144 144
145 145 def __getitem__(self, k):
146 146 return self.__dict__[k]
147 147
148 148
149 149 def msg_header(msg_id, msg_type, username, session):
150 150 date = datetime.now()
151 151 return locals()
152 152
153 153 def extract_header(msg_or_header):
154 154 """Given a message or header, return the header."""
155 155 if not msg_or_header:
156 156 return {}
157 157 try:
158 158 # See if msg_or_header is the entire message.
159 159 h = msg_or_header['header']
160 160 except KeyError:
161 161 try:
162 162 # See if msg_or_header is just the header
163 163 h = msg_or_header['msg_id']
164 164 except KeyError:
165 165 raise
166 166 else:
167 167 h = msg_or_header
168 168 if not isinstance(h, dict):
169 169 h = dict(h)
170 170 return h
171 171
172 172 class Session(Configurable):
173 173 """Object for handling serialization and sending of messages.
174 174
175 175 The Session object handles building messages and sending them
176 176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 177 other over the network via Session objects, and only need to work with the
178 178 dict-based IPython message spec. The Session will handle
179 179 serialization/deserialization, security, and metadata.
180 180
181 181 Sessions support configurable serialiization via packer/unpacker traits,
182 182 and signing with HMAC digests via the key/keyfile traits.
183 183
184 184 Parameters
185 185 ----------
186 186
187 187 debug : bool
188 188 whether to trigger extra debugging statements
189 189 packer/unpacker : str : 'json', 'pickle' or import_string
190 190 importstrings for methods to serialize message parts. If just
191 191 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 192 Otherwise, the entire importstring must be used.
193 193
194 194 The functions must accept at least valid JSON input, and output *bytes*.
195 195
196 196 For example, to use msgpack:
197 197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 198 pack/unpack : callables
199 199 You can also set the pack/unpack callables for serialization directly.
200 200 session : bytes
201 201 the ID of this Session object. The default is to generate a new UUID.
202 202 username : unicode
203 203 username added to message headers. The default is to ask the OS.
204 204 key : bytes
205 205 The key used to initialize an HMAC signature. If unset, messages
206 206 will not be signed or checked.
207 207 keyfile : filepath
208 208 The file containing a key. If this is set, `key` will be initialized
209 209 to the contents of the file.
210 210
211 211 """
212 212
213 213 debug=Bool(False, config=True, help="""Debug output in the Session""")
214 214
215 215 packer = DottedObjectName('json',config=True,
216 216 help="""The name of the packer for serializing messages.
217 217 Should be one of 'json', 'pickle', or an import name
218 218 for a custom callable serializer.""")
219 219 def _packer_changed(self, name, old, new):
220 220 if new.lower() == 'json':
221 221 self.pack = json_packer
222 222 self.unpack = json_unpacker
223 223 elif new.lower() == 'pickle':
224 224 self.pack = pickle_packer
225 225 self.unpack = pickle_unpacker
226 226 else:
227 227 self.pack = import_item(str(new))
228 228
229 229 unpacker = DottedObjectName('json', config=True,
230 230 help="""The name of the unpacker for unserializing messages.
231 231 Only used with custom functions for `packer`.""")
232 232 def _unpacker_changed(self, name, old, new):
233 233 if new.lower() == 'json':
234 234 self.pack = json_packer
235 235 self.unpack = json_unpacker
236 236 elif new.lower() == 'pickle':
237 237 self.pack = pickle_packer
238 238 self.unpack = pickle_unpacker
239 239 else:
240 240 self.unpack = import_item(str(new))
241 241
242 242 session = CBytes(b'', config=True,
243 243 help="""The UUID identifying this session.""")
244 244 def _session_default(self):
245 245 return bytes(uuid.uuid4())
246 246
247 247 username = Unicode(os.environ.get('USER','username'), config=True,
248 248 help="""Username for the Session. Default is your system username.""")
249 249
250 250 # message signature related traits:
251 251 key = CBytes(b'', config=True,
252 252 help="""execution key, for extra authentication.""")
253 253 def _key_changed(self, name, old, new):
254 254 if new:
255 255 self.auth = hmac.HMAC(new)
256 256 else:
257 257 self.auth = None
258 258 auth = Instance(hmac.HMAC)
259 259 digest_history = Set()
260 260
261 261 keyfile = Unicode('', config=True,
262 262 help="""path to file containing execution key.""")
263 263 def _keyfile_changed(self, name, old, new):
264 264 with open(new, 'rb') as f:
265 265 self.key = f.read().strip()
266 266
267 267 pack = Any(default_packer) # the actual packer function
268 268 def _pack_changed(self, name, old, new):
269 269 if not callable(new):
270 270 raise TypeError("packer must be callable, not %s"%type(new))
271 271
272 272 unpack = Any(default_unpacker) # the actual packer function
273 273 def _unpack_changed(self, name, old, new):
274 274 # unpacker is not checked - it is assumed to be
275 275 if not callable(new):
276 276 raise TypeError("unpacker must be callable, not %s"%type(new))
277 277
278 278 def __init__(self, **kwargs):
279 279 """create a Session object
280 280
281 281 Parameters
282 282 ----------
283 283
284 284 debug : bool
285 285 whether to trigger extra debugging statements
286 286 packer/unpacker : str : 'json', 'pickle' or import_string
287 287 importstrings for methods to serialize message parts. If just
288 288 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 289 Otherwise, the entire importstring must be used.
290 290
291 291 The functions must accept at least valid JSON input, and output
292 292 *bytes*.
293 293
294 294 For example, to use msgpack:
295 295 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 296 pack/unpack : callables
297 297 You can also set the pack/unpack callables for serialization
298 298 directly.
299 299 session : bytes
300 300 the ID of this Session object. The default is to generate a new
301 301 UUID.
302 302 username : unicode
303 303 username added to message headers. The default is to ask the OS.
304 304 key : bytes
305 305 The key used to initialize an HMAC signature. If unset, messages
306 306 will not be signed or checked.
307 307 keyfile : filepath
308 308 The file containing a key. If this is set, `key` will be
309 309 initialized to the contents of the file.
310 310 """
311 311 super(Session, self).__init__(**kwargs)
312 312 self._check_packers()
313 313 self.none = self.pack({})
314 314
315 315 @property
316 316 def msg_id(self):
317 317 """always return new uuid"""
318 318 return str(uuid.uuid4())
319 319
320 320 def _check_packers(self):
321 321 """check packers for binary data and datetime support."""
322 322 pack = self.pack
323 323 unpack = self.unpack
324 324
325 325 # check simple serialization
326 326 msg = dict(a=[1,'hi'])
327 327 try:
328 328 packed = pack(msg)
329 329 except Exception:
330 330 raise ValueError("packer could not serialize a simple message")
331 331
332 332 # ensure packed message is bytes
333 333 if not isinstance(packed, bytes):
334 334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 335
336 336 # check that unpack is pack's inverse
337 337 try:
338 338 unpacked = unpack(packed)
339 339 except Exception:
340 340 raise ValueError("unpacker could not handle the packer's output")
341 341
342 342 # check datetime support
343 343 msg = dict(t=datetime.now())
344 344 try:
345 345 unpacked = unpack(pack(msg))
346 346 except Exception:
347 347 self.pack = lambda o: pack(squash_dates(o))
348 348 self.unpack = lambda s: extract_dates(unpack(s))
349 349
350 350 def msg_header(self, msg_type):
351 351 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 352
353 353 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 354 """Return the nested message dict.
355 355
356 356 This format is different from what is sent over the wire. The
357 357 serialize/unserialize methods converts this nested message dict to the wire
358 358 format, which is a list of message parts.
359 359 """
360 360 msg = {}
361 361 msg['header'] = self.msg_header(msg_type) if header is None else header
362 362 msg['parent_header'] = {} if parent is None else extract_header(parent)
363 363 msg['content'] = {} if content is None else content
364 364 sub = {} if subheader is None else subheader
365 365 msg['header'].update(sub)
366 366 return msg
367 367
368 368 def sign(self, msg_list):
369 369 """Sign a message with HMAC digest. If no auth, return b''.
370 370
371 371 Parameters
372 372 ----------
373 373 msg_list : list
374 374 The [p_header,p_parent,p_content] part of the message list.
375 375 """
376 376 if self.auth is None:
377 377 return b''
378 378 h = self.auth.copy()
379 379 for m in msg_list:
380 380 h.update(m)
381 381 return h.hexdigest()
382 382
383 383 def serialize(self, msg, ident=None):
384 384 """Serialize the message components to bytes.
385 385
386 386 This is roughly the inverse of unserialize. The serialize/unserialize
387 387 methods work with full message lists, whereas pack/unpack work with
388 388 the individual message parts in the message list.
389 389
390 390 Parameters
391 391 ----------
392 392 msg : dict or Message
393 393 The nexted message dict as returned by the self.msg method.
394 394
395 395 Returns
396 396 -------
397 397 msg_list : list
398 398 The list of bytes objects to be sent with the format:
399 399 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
400 400 buffer1,buffer2,...]. In this list, the p_* entities are
401 401 the packed or serialized versions, so if JSON is used, these
402 402 are uft8 encoded JSON strings.
403 403 """
404 404 content = msg.get('content', {})
405 405 if content is None:
406 406 content = self.none
407 407 elif isinstance(content, dict):
408 408 content = self.pack(content)
409 409 elif isinstance(content, bytes):
410 410 # content is already packed, as in a relayed message
411 411 pass
412 412 elif isinstance(content, unicode):
413 413 # should be bytes, but JSON often spits out unicode
414 414 content = content.encode('utf8')
415 415 else:
416 416 raise TypeError("Content incorrect type: %s"%type(content))
417 417
418 418 real_message = [self.pack(msg['header']),
419 419 self.pack(msg['parent_header']),
420 420 content
421 421 ]
422 422
423 423 to_send = []
424 424
425 425 if isinstance(ident, list):
426 426 # accept list of idents
427 427 to_send.extend(ident)
428 428 elif ident is not None:
429 429 to_send.append(ident)
430 430 to_send.append(DELIM)
431 431
432 432 signature = self.sign(real_message)
433 433 to_send.append(signature)
434 434
435 435 to_send.extend(real_message)
436 436
437 437 return to_send
438 438
439 def send(self, stream, msg_or_type, content=None, parent=None, ident=None
439 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
440 440 buffers=None, subheader=None, track=False, header=None):
441 441 """Build and send a message via stream or socket.
442 442
443 443 The message format used by this function internally is as follows:
444 444
445 445 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
446 446 buffer1,buffer2,...]
447 447
448 448 The serialize/unserialize methods convert the nested message dict into this
449 449 format.
450 450
451 451 Parameters
452 452 ----------
453 453
454 454 stream : zmq.Socket or ZMQStream
455 455 The socket-like object used to send the data.
456 456 msg_or_type : str or Message/dict
457 457 Normally, msg_or_type will be a msg_type unless a message is being
458 458 sent more than once.
459 459
460 460 content : dict or None
461 461 The content of the message (ignored if msg_or_type is a message).
462 462 header : dict or None
463 463 The header dict for the message (ignores if msg_to_type is a message).
464 464 parent : Message or dict or None
465 465 The parent or parent header describing the parent of this message
466 466 (ignored if msg_or_type is a message).
467 467 ident : bytes or list of bytes
468 468 The zmq.IDENTITY routing path.
469 469 subheader : dict or None
470 470 Extra header keys for this message's header (ignored if msg_or_type
471 471 is a message).
472 472 buffers : list or None
473 473 The already-serialized buffers to be appended to the message.
474 474 track : bool
475 475 Whether to track. Only for use with Sockets, because ZMQStream
476 476 objects cannot track messages.
477 477
478 478 Returns
479 479 -------
480 480 msg : dict
481 481 The constructed message.
482 482 (msg,tracker) : (dict, MessageTracker)
483 483 if track=True, then a 2-tuple will be returned,
484 484 the first element being the constructed
485 485 message, and the second being the MessageTracker
486 486
487 487 """
488 488
489 489 if not isinstance(stream, (zmq.Socket, ZMQStream)):
490 490 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
491 491 elif track and isinstance(stream, ZMQStream):
492 492 raise TypeError("ZMQStream cannot track messages")
493 493
494 494 if isinstance(msg_or_type, (Message, dict)):
495 495 # We got a Message or message dict, not a msg_type so don't
496 496 # build a new Message.
497 497 msg = msg_or_type
498 498 else:
499 499 msg = self.msg(msg_or_type, content=content, parent=parent,
500 500 subheader=subheader, header=header)
501 501
502 502 buffers = [] if buffers is None else buffers
503 503 to_send = self.serialize(msg, ident)
504 504 flag = 0
505 505 if buffers:
506 506 flag = zmq.SNDMORE
507 507 _track = False
508 508 else:
509 509 _track=track
510 510 if track:
511 511 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
512 512 else:
513 513 tracker = stream.send_multipart(to_send, flag, copy=False)
514 514 for b in buffers[:-1]:
515 515 stream.send(b, flag, copy=False)
516 516 if buffers:
517 517 if track:
518 518 tracker = stream.send(buffers[-1], copy=False, track=track)
519 519 else:
520 520 tracker = stream.send(buffers[-1], copy=False)
521 521
522 522 # omsg = Message(msg)
523 523 if self.debug:
524 524 pprint.pprint(msg)
525 525 pprint.pprint(to_send)
526 526 pprint.pprint(buffers)
527 527
528 528 msg['tracker'] = tracker
529 529
530 530 return msg
531 531
532 532 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
533 533 """Send a raw message via ident path.
534 534
535 535 This method is used to send a already serialized message.
536 536
537 537 Parameters
538 538 ----------
539 539 stream : ZMQStream or Socket
540 540 The ZMQ stream or socket to use for sending the message.
541 541 msg_list : list
542 542 The serialized list of messages to send. This only includes the
543 543 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
544 544 the message.
545 545 ident : ident or list
546 546 A single ident or a list of idents to use in sending.
547 547 """
548 548 to_send = []
549 549 if isinstance(ident, bytes):
550 550 ident = [ident]
551 551 if ident is not None:
552 552 to_send.extend(ident)
553 553
554 554 to_send.append(DELIM)
555 555 to_send.append(self.sign(msg_list))
556 556 to_send.extend(msg_list)
557 557 stream.send_multipart(msg_list, flags, copy=copy)
558 558
559 559 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
560 560 """Receive and unpack a message.
561 561
562 562 Parameters
563 563 ----------
564 564 socket : ZMQStream or Socket
565 565 The socket or stream to use in receiving.
566 566
567 567 Returns
568 568 -------
569 569 [idents], msg
570 570 [idents] is a list of idents and msg is a nested message dict of
571 571 same format as self.msg returns.
572 572 """
573 573 if isinstance(socket, ZMQStream):
574 574 socket = socket.socket
575 575 try:
576 576 msg_list = socket.recv_multipart(mode)
577 577 except zmq.ZMQError as e:
578 578 if e.errno == zmq.EAGAIN:
579 579 # We can convert EAGAIN to None as we know in this case
580 580 # recv_multipart won't return None.
581 581 return None,None
582 582 else:
583 583 raise
584 584 # split multipart message into identity list and message dict
585 585 # invalid large messages can cause very expensive string comparisons
586 586 idents, msg_list = self.feed_identities(msg_list, copy)
587 587 try:
588 588 return idents, self.unserialize(msg_list, content=content, copy=copy)
589 589 except Exception as e:
590 590 print (idents, msg_list)
591 591 # TODO: handle it
592 592 raise e
593 593
594 594 def feed_identities(self, msg_list, copy=True):
595 595 """Split the identities from the rest of the message.
596 596
597 597 Feed until DELIM is reached, then return the prefix as idents and
598 598 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
599 599 but that would be silly.
600 600
601 601 Parameters
602 602 ----------
603 603 msg_list : a list of Message or bytes objects
604 604 The message to be split.
605 605 copy : bool
606 606 flag determining whether the arguments are bytes or Messages
607 607
608 608 Returns
609 609 -------
610 610 (idents, msg_list) : two lists
611 611 idents will always be a list of bytes, each of which is a ZMQ
612 612 identity. msg_list will be a list of bytes or zmq.Messages of the
613 613 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
614 614 should be unpackable/unserializable via self.unserialize at this
615 615 point.
616 616 """
617 617 if copy:
618 618 idx = msg_list.index(DELIM)
619 619 return msg_list[:idx], msg_list[idx+1:]
620 620 else:
621 621 failed = True
622 622 for idx,m in enumerate(msg_list):
623 623 if m.bytes == DELIM:
624 624 failed = False
625 625 break
626 626 if failed:
627 627 raise ValueError("DELIM not in msg_list")
628 628 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
629 629 return [m.bytes for m in idents], msg_list
630 630
631 631 def unserialize(self, msg_list, content=True, copy=True):
632 632 """Unserialize a msg_list to a nested message dict.
633 633
634 634 This is roughly the inverse of serialize. The serialize/unserialize
635 635 methods work with full message lists, whereas pack/unpack work with
636 636 the individual message parts in the message list.
637 637
638 638 Parameters:
639 639 -----------
640 640 msg_list : list of bytes or Message objects
641 641 The list of message parts of the form [HMAC,p_header,p_parent,
642 642 p_content,buffer1,buffer2,...].
643 643 content : bool (True)
644 644 Whether to unpack the content dict (True), or leave it packed
645 645 (False).
646 646 copy : bool (True)
647 647 Whether to return the bytes (True), or the non-copying Message
648 648 object in each place (False).
649 649
650 650 Returns
651 651 -------
652 652 msg : dict
653 653 The nested message dict with top-level keys [header, parent_header,
654 654 content, buffers].
655 655 """
656 656 minlen = 4
657 657 message = {}
658 658 if not copy:
659 659 for i in range(minlen):
660 660 msg_list[i] = msg_list[i].bytes
661 661 if self.auth is not None:
662 662 signature = msg_list[0]
663 663 if signature in self.digest_history:
664 664 raise ValueError("Duplicate Signature: %r"%signature)
665 665 self.digest_history.add(signature)
666 666 check = self.sign(msg_list[1:4])
667 667 if not signature == check:
668 668 raise ValueError("Invalid Signature: %r"%signature)
669 669 if not len(msg_list) >= minlen:
670 670 raise TypeError("malformed message, must have at least %i elements"%minlen)
671 671 message['header'] = self.unpack(msg_list[1])
672 672 message['parent_header'] = self.unpack(msg_list[2])
673 673 if content:
674 674 message['content'] = self.unpack(msg_list[3])
675 675 else:
676 676 message['content'] = msg_list[3]
677 677
678 678 message['buffers'] = msg_list[4:]
679 679 return message
680 680
681 681 def test_msg2obj():
682 682 am = dict(x=1)
683 683 ao = Message(am)
684 684 assert ao.x == am['x']
685 685
686 686 am['y'] = dict(z=1)
687 687 ao = Message(am)
688 688 assert ao.y.z == am['y']['z']
689 689
690 690 k1, k2 = 'y', 'z'
691 691 assert ao[k1][k2] == am[k1][k2]
692 692
693 693 am2 = dict(ao)
694 694 assert am['x'] == am2['x']
695 695 assert am['y']['z'] == am2['y']['z']
696 696
@@ -1,109 +1,120 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29 29 class TestSession(SessionTestCase):
30 30
31 31 def test_msg(self):
32 32 """message format"""
33 33 msg = self.session.msg('execute')
34 thekeys = set('header msg_id parent_header msg_type content'.split())
34 thekeys = set('header parent_header content'.split())
35 35 s = set(msg.keys())
36 36 self.assertEquals(s, thekeys)
37 37 self.assertTrue(isinstance(msg['content'],dict))
38 38 self.assertTrue(isinstance(msg['header'],dict))
39 39 self.assertTrue(isinstance(msg['parent_header'],dict))
40 40 self.assertEquals(msg['header']['msg_type'], 'execute')
41
41
42 def test_serialize(self):
43 msg = self.session.msg('execute')
44 msg_list = self.session.serialize(msg, ident=b'foo')
45 ident, msg_list = self.session.feed_identities(msg_list)
46 new_msg = self.session.unserialize(msg_list)
47 self.assertEquals(ident[0], b'foo')
48 self.assertEquals(new_msg['header'],msg['header'])
49 self.assertEquals(new_msg['content'],msg['content'])
50 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
51
42 52 def test_args(self):
43 53 """initialization arguments for Session"""
44 54 s = self.session
45 55 self.assertTrue(s.pack is ss.default_packer)
46 56 self.assertTrue(s.unpack is ss.default_unpacker)
47 57 self.assertEquals(s.username, os.environ.get('USER', 'username'))
48 58
49 59 s = ss.Session()
50 60 self.assertEquals(s.username, os.environ.get('USER', 'username'))
51 61
52 62 self.assertRaises(TypeError, ss.Session, pack='hi')
53 63 self.assertRaises(TypeError, ss.Session, unpack='hi')
54 64 u = str(uuid.uuid4())
55 65 s = ss.Session(username='carrot', session=u)
56 66 self.assertEquals(s.session, u)
57 67 self.assertEquals(s.username, 'carrot')
58 68
59 69 def test_tracking(self):
60 70 """test tracking messages"""
61 71 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
62 72 s = self.session
63 73 stream = ZMQStream(a)
64 74 msg = s.send(a, 'hello', track=False)
65 75 self.assertTrue(msg['tracker'] is None)
66 76 msg = s.send(a, 'hello', track=True)
67 77 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
68 78 M = zmq.Message(b'hi there', track=True)
69 79 msg = s.send(a, 'hello', buffers=[M], track=True)
70 80 t = msg['tracker']
71 81 self.assertTrue(isinstance(t, zmq.MessageTracker))
72 82 self.assertRaises(zmq.NotDone, t.wait, .1)
73 83 del M
74 84 t.wait(1) # this will raise
75 85
76 86
77 87 # def test_rekey(self):
78 88 # """rekeying dict around json str keys"""
79 89 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
80 90 # self.assertRaises(KeyError, ss.rekey, d)
81 91 #
82 92 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
83 93 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
84 94 # rd = ss.rekey(d)
85 95 # self.assertEquals(d2,rd)
86 96 #
87 97 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
88 98 # d2 = {1.5:d['1.5'],1:d['1']}
89 99 # rd = ss.rekey(d)
90 100 # self.assertEquals(d2,rd)
91 101 #
92 102 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
93 103 # self.assertRaises(KeyError, ss.rekey, d)
94 104 #
95 105 def test_unique_msg_ids(self):
96 106 """test that messages receive unique ids"""
97 107 ids = set()
98 108 for i in range(2**12):
99 109 h = self.session.msg_header('test')
100 110 msg_id = h['msg_id']
101 111 self.assertTrue(msg_id not in ids)
102 112 ids.add(msg_id)
103 113
104 114 def test_feed_identities(self):
105 115 """scrub the front for zmq IDENTITIES"""
106 116 theids = "engine client other".split()
107 117 content = dict(code='whoda',stuff=object())
108 118 themsg = self.session.msg('execute',content=content)
109 119 pmsg = theids
120
General Comments 0
You need to be logged in to leave comments. Login now