##// END OF EJS Templates
Merge branch 'sessionwork'
Brian E. Granger -
r4517:8de5e115 merge
parent child Browse files
Show More
@@ -96,7 +96,7 b' class BaseFrontendMixin(object):'
96 96 """ Calls the frontend handler associated with the message type of the
97 97 given message.
98 98 """
99 msg_type = msg['msg_type']
99 msg_type = msg['header']['msg_type']
100 100 handler = getattr(self, '_handle_' + msg_type, None)
101 101 if handler:
102 102 handler(msg)
@@ -66,7 +66,7 b' class QtShellSocketChannel(SocketChannelQObject, ShellSocketChannel):'
66 66 self.message_received.emit(msg)
67 67
68 68 # Emit signals for specialized message types.
69 msg_type = msg['msg_type']
69 msg_type = msg['header']['msg_type']
70 70 signal = getattr(self, msg_type, None)
71 71 if signal:
72 72 signal.emit(msg)
@@ -122,7 +122,7 b' class QtSubSocketChannel(SocketChannelQObject, SubSocketChannel):'
122 122 # Emit the generic signal.
123 123 self.message_received.emit(msg)
124 124 # Emit signals for specialized message types.
125 msg_type = msg['msg_type']
125 msg_type = msg['header']['msg_type']
126 126 signal = getattr(self, msg_type + '_received', None)
127 127 if signal:
128 128 signal.emit(msg)
@@ -155,7 +155,7 b' class QtStdInSocketChannel(SocketChannelQObject, StdInSocketChannel):'
155 155 self.message_received.emit(msg)
156 156
157 157 # Emit signals for specialized message types.
158 msg_type = msg['msg_type']
158 msg_type = msg['header']['msg_type']
159 159 if msg_type == 'input_request':
160 160 self.input_requested.emit(msg)
161 161
@@ -670,7 +670,7 b' class Client(HasTraits):'
670 670 while msg is not None:
671 671 if self.debug:
672 672 pprint(msg)
673 msg_type = msg['msg_type']
673 msg_type = msg['header']['msg_type']
674 674 handler = self._notification_handlers.get(msg_type, None)
675 675 if handler is None:
676 676 raise Exception("Unhandled message type: %s"%msg.msg_type)
@@ -684,7 +684,7 b' class Client(HasTraits):'
684 684 while msg is not None:
685 685 if self.debug:
686 686 pprint(msg)
687 msg_type = msg['msg_type']
687 msg_type = msg['header']['msg_type']
688 688 handler = self._queue_handlers.get(msg_type, None)
689 689 if handler is None:
690 690 raise Exception("Unhandled message type: %s"%msg.msg_type)
@@ -729,7 +729,7 b' class Client(HasTraits):'
729 729 msg_id = parent['msg_id']
730 730 content = msg['content']
731 731 header = msg['header']
732 msg_type = msg['msg_type']
732 msg_type = msg['header']['msg_type']
733 733
734 734 # init metadata:
735 735 md = self.metadata[msg_id]
@@ -994,7 +994,7 b' class Client(HasTraits):'
994 994 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
995 995 subheader=subheader, track=track)
996 996
997 msg_id = msg['msg_id']
997 msg_id = msg['header']['msg_id']
998 998 self.outstanding.add(msg_id)
999 999 if ident:
1000 1000 # possibly routed to a specific engine
@@ -523,7 +523,7 b' class DirectView(View):'
523 523 ident=ident)
524 524 if track:
525 525 trackers.append(msg['tracker'])
526 msg_ids.append(msg['msg_id'])
526 msg_ids.append(msg['header']['msg_id'])
527 527 tracker = None if track is False else zmq.MessageTracker(*trackers)
528 528 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
529 529 if block:
@@ -980,7 +980,7 b' class LoadBalancedView(View):'
980 980 subheader=subheader)
981 981 tracker = None if track is False else msg['tracker']
982 982
983 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
983 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
984 984
985 985 if block:
986 986 try:
@@ -485,7 +485,7 b' class Hub(SessionFactory):'
485 485 return
486 486 client_id = idents[0]
487 487 try:
488 msg = self.session.unpack_message(msg, content=True)
488 msg = self.session.unserialize(msg, content=True)
489 489 except Exception:
490 490 content = error.wrap_exception()
491 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
@@ -494,7 +494,7 b' class Hub(SessionFactory):'
494 494 return
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 msg_type = msg['msg_type']
497 msg_type = msg['header']['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
@@ -550,7 +550,7 b' class Hub(SessionFactory):'
550 550 return
551 551 queue_id, client_id = idents[:2]
552 552 try:
553 msg = self.session.unpack_message(msg)
553 msg = self.session.unserialize(msg)
554 554 except Exception:
555 555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 556 return
@@ -597,7 +597,7 b' class Hub(SessionFactory):'
597 597
598 598 client_id, queue_id = idents[:2]
599 599 try:
600 msg = self.session.unpack_message(msg)
600 msg = self.session.unserialize(msg)
601 601 except Exception:
602 602 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
603 603 queue_id,client_id, msg), exc_info=True)
@@ -647,7 +647,7 b' class Hub(SessionFactory):'
647 647 client_id = idents[0]
648 648
649 649 try:
650 msg = self.session.unpack_message(msg)
650 msg = self.session.unserialize(msg)
651 651 except Exception:
652 652 self.log.error("task::client %r sent invalid task message: %r"%(
653 653 client_id, msg), exc_info=True)
@@ -697,7 +697,7 b' class Hub(SessionFactory):'
697 697 """save the result of a completed task."""
698 698 client_id = idents[0]
699 699 try:
700 msg = self.session.unpack_message(msg)
700 msg = self.session.unserialize(msg)
701 701 except Exception:
702 702 self.log.error("task::invalid task result message send to %r: %r"%(
703 703 client_id, msg), exc_info=True)
@@ -744,7 +744,7 b' class Hub(SessionFactory):'
744 744
745 745 def save_task_destination(self, idents, msg):
746 746 try:
747 msg = self.session.unpack_message(msg, content=True)
747 msg = self.session.unserialize(msg, content=True)
748 748 except Exception:
749 749 self.log.error("task::invalid task tracking message", exc_info=True)
750 750 return
@@ -781,7 +781,7 b' class Hub(SessionFactory):'
781 781 """save an iopub message into the db"""
782 782 # print (topics)
783 783 try:
784 msg = self.session.unpack_message(msg, content=True)
784 msg = self.session.unserialize(msg, content=True)
785 785 except Exception:
786 786 self.log.error("iopub::invalid IOPub message", exc_info=True)
787 787 return
@@ -791,7 +791,7 b' class Hub(SessionFactory):'
791 791 self.log.error("iopub::invalid IOPub message: %r"%msg)
792 792 return
793 793 msg_id = parent['msg_id']
794 msg_type = msg['msg_type']
794 msg_type = msg['header']['msg_type']
795 795 content = msg['content']
796 796
797 797 # ensure msg_id is in db
@@ -1165,7 +1165,7 b' class Hub(SessionFactory):'
1165 1165 msg = self.session.msg(header['msg_type'])
1166 1166 msg['content'] = rec['content']
1167 1167 msg['header'] = header
1168 msg['msg_id'] = rec['msg_id']
1168 msg['header']['msg_id'] = rec['msg_id']
1169 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170 1170
1171 1171 finish(dict(status='ok'))
@@ -211,12 +211,12 b' class TaskScheduler(SessionFactory):'
211 211 self.log.warn("task::Invalid Message: %r",msg)
212 212 return
213 213 try:
214 msg = self.session.unpack_message(msg)
214 msg = self.session.unserialize(msg)
215 215 except ValueError:
216 216 self.log.warn("task::Unauthorized message from: %r"%idents)
217 217 return
218 218
219 msg_type = msg['msg_type']
219 msg_type = msg['header']['msg_type']
220 220
221 221 handler = self._notification_handlers.get(msg_type, None)
222 222 if handler is None:
@@ -307,7 +307,7 b' class TaskScheduler(SessionFactory):'
307 307 self.notifier_stream.flush()
308 308 try:
309 309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
310 msg = self.session.unpack_message(msg, content=False, copy=False)
310 msg = self.session.unserialize(msg, content=False, copy=False)
311 311 except Exception:
312 312 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
313 313 return
@@ -515,7 +515,7 b' class TaskScheduler(SessionFactory):'
515 515 """dispatch method for result replies"""
516 516 try:
517 517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
518 msg = self.session.unpack_message(msg, content=False, copy=False)
518 msg = self.session.unserialize(msg, content=False, copy=False)
519 519 engine = idents[0]
520 520 try:
521 521 idx = self.targets.index(engine)
@@ -90,7 +90,7 b' class EngineFactory(RegistrationFactory):'
90 90 loop = self.loop
91 91 identity = self.bident
92 92 idents,msg = self.session.feed_identities(msg)
93 msg = Message(self.session.unpack_message(msg))
93 msg = Message(self.session.unserialize(msg))
94 94
95 95 if msg.content.status == 'ok':
96 96 self.id = int(msg.content.id)
@@ -40,11 +40,11 b' class KernelStarter(object):'
40 40 def dispatch_request(self, raw_msg):
41 41 idents, msg = self.session.feed_identities()
42 42 try:
43 msg = self.session.unpack_message(msg, content=False)
43 msg = self.session.unserialize(msg, content=False)
44 44 except:
45 45 print ("bad msg: %s"%msg)
46 46
47 msgtype = msg['msg_type']
47 msgtype = msg['header']['msg_type']
48 48 handler = self.handlers.get(msgtype, None)
49 49 if handler is None:
50 50 self.downstream.send_multipart(raw_msg, copy=False)
@@ -54,11 +54,11 b' class KernelStarter(object):'
54 54 def dispatch_reply(self, raw_msg):
55 55 idents, msg = self.session.feed_identities()
56 56 try:
57 msg = self.session.unpack_message(msg, content=False)
57 msg = self.session.unserialize(msg, content=False)
58 58 except:
59 59 print ("bad msg: %s"%msg)
60 60
61 msgtype = msg['msg_type']
61 msgtype = msg['header']['msg_type']
62 62 handler = self.handlers.get(msgtype, None)
63 63 if handler is None:
64 64 self.upstream.send_multipart(raw_msg, copy=False)
@@ -150,7 +150,7 b' class Kernel(SessionFactory):'
150 150
151 151 self.log.info("Aborting:")
152 152 self.log.info(str(msg))
153 msg_type = msg['msg_type']
153 msg_type = msg['header']['msg_type']
154 154 reply_type = msg_type.split('_')[0] + '_reply'
155 155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 156 # self.reply_socket.send(ident,zmq.SNDMORE)
@@ -195,7 +195,7 b' class Kernel(SessionFactory):'
195 195 def dispatch_control(self, msg):
196 196 idents,msg = self.session.feed_identities(msg, copy=False)
197 197 try:
198 msg = self.session.unpack_message(msg, content=True, copy=False)
198 msg = self.session.unserialize(msg, content=True, copy=False)
199 199 except:
200 200 self.log.error("Invalid Message", exc_info=True)
201 201 return
@@ -204,10 +204,11 b' class Kernel(SessionFactory):'
204 204
205 205 header = msg['header']
206 206 msg_id = header['msg_id']
207 msg_type = header['msg_type']
207 208
208 handler = self.control_handlers.get(msg['msg_type'], None)
209 handler = self.control_handlers.get(msg_type, None)
209 210 if handler is None:
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg_type)
211 212 else:
212 213 handler(self.control_stream, idents, msg)
213 214
@@ -373,7 +374,7 b' class Kernel(SessionFactory):'
373 374 self.control_stream.flush()
374 375 idents,msg = self.session.feed_identities(msg, copy=False)
375 376 try:
376 msg = self.session.unpack_message(msg, content=True, copy=False)
377 msg = self.session.unserialize(msg, content=True, copy=False)
377 378 except:
378 379 self.log.error("Invalid Message", exc_info=True)
379 380 return
@@ -383,17 +384,18 b' class Kernel(SessionFactory):'
383 384
384 385 header = msg['header']
385 386 msg_id = header['msg_id']
387 msg_type = msg['header']['msg_type']
386 388 if self.check_aborted(msg_id):
387 389 self.aborted.remove(msg_id)
388 390 # is it safe to assume a msg_id will not be resubmitted?
389 reply_type = msg['msg_type'].split('_')[0] + '_reply'
391 reply_type = msg_type.split('_')[0] + '_reply'
390 392 status = {'status' : 'aborted'}
391 393 reply_msg = self.session.send(stream, reply_type, subheader=status,
392 394 content=status, parent=msg, ident=idents)
393 395 return
394 handler = self.shell_handlers.get(msg['msg_type'], None)
396 handler = self.shell_handlers.get(msg_type, None)
395 397 if handler is None:
396 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
398 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg_type)
397 399 else:
398 400 handler(stream, idents, msg)
399 401
@@ -56,8 +56,9 b' class TestDictBackend(TestCase):'
56 56 msg = self.session.msg('apply_request', content=dict(a=5))
57 57 msg['buffers'] = []
58 58 rec = init_record(msg)
59 msg_ids.append(msg['msg_id'])
60 self.db.add_record(msg['msg_id'], rec)
59 msg_id = msg['header']['msg_id']
60 msg_ids.append(msg_id)
61 self.db.add_record(msg_id, rec)
61 62 return msg_ids
62 63
63 64 def test_add_record(self):
@@ -125,6 +125,8 b' class Kernel(Configurable):'
125 125 if msg is None:
126 126 return
127 127
128 msg_type = msg['header']['msg_type']
129
128 130 # This assert will raise in versions of zeromq 2.0.7 and lesser.
129 131 # We now require 2.0.8 or above, so we can uncomment for safety.
130 132 # print(ident,msg, file=sys.__stdout__)
@@ -133,11 +135,11 b' class Kernel(Configurable):'
133 135 # Print some info about this message and leave a '--->' marker, so it's
134 136 # easier to trace visually the message chain when debugging. Each
135 137 # handler prints its message at the end.
136 self.log.debug('\n*** MESSAGE TYPE:'+str(msg['msg_type'])+'***')
138 self.log.debug('\n*** MESSAGE TYPE:'+str(msg_type)+'***')
137 139 self.log.debug(' Content: '+str(msg['content'])+'\n --->\n ')
138 140
139 141 # Find and call actual handler for message
140 handler = self.handlers.get(msg['msg_type'], None)
142 handler = self.handlers.get(msg_type, None)
141 143 if handler is None:
142 144 self.log.error("UNKNOWN MESSAGE TYPE:" +str(msg))
143 145 else:
@@ -375,7 +377,7 b' class Kernel(Configurable):'
375 377 "Unexpected missing message part."
376 378
377 379 self.log.debug("Aborting:\n"+str(Message(msg)))
378 msg_type = msg['msg_type']
380 msg_type = msg['header']['msg_type']
379 381 reply_type = msg_type.split('_')[0] + '_reply'
380 382 reply_msg = self.session.send(self.shell_socket, reply_type,
381 383 {'status' : 'aborted'}, msg, ident=ident)
@@ -190,7 +190,7 b' class Kernel(HasTraits):'
190 190 else:
191 191 assert ident is not None, "Missing message part."
192 192 self.log.debug("Aborting: %s"%Message(msg))
193 msg_type = msg['msg_type']
193 msg_type = msg['header']['msg_type']
194 194 reply_type = msg_type.split('_')[0] + '_reply'
195 195 reply_msg = self.session.send(self.shell_socket, reply_type, {'status':'aborted'}, msg, ident=ident)
196 196 self.log.debug(Message(reply_msg))
@@ -244,7 +244,7 b' class Session(Configurable):'
244 244 def _session_default(self):
245 245 return bytes(uuid.uuid4())
246 246
247 username = Unicode(os.environ.get('USER','username'), config=True,
247 username = Unicode(os.environ.get('USER',u'username'), config=True,
248 248 help="""Username for the Session. Default is your system username.""")
249 249
250 250 # message signature related traits:
@@ -350,18 +350,16 b' class Session(Configurable):'
350 350 def msg_header(self, msg_type):
351 351 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 352
353 def msg(self, msg_type, content=None, parent=None, subheader=None):
353 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
354 354 """Return the nested message dict.
355 355
356 356 This format is different from what is sent over the wire. The
357 self.serialize method converts this nested message dict to the wire
358 format, which uses a message list.
357 serialize/unserialize methods converts this nested message dict to the wire
358 format, which is a list of message parts.
359 359 """
360 360 msg = {}
361 msg['header'] = self.msg_header(msg_type)
362 msg['msg_id'] = msg['header']['msg_id']
361 msg['header'] = self.msg_header(msg_type) if header is None else header
363 362 msg['parent_header'] = {} if parent is None else extract_header(parent)
364 msg['msg_type'] = msg_type
365 363 msg['content'] = {} if content is None else content
366 364 sub = {} if subheader is None else subheader
367 365 msg['header'].update(sub)
@@ -385,6 +383,10 b' class Session(Configurable):'
385 383 def serialize(self, msg, ident=None):
386 384 """Serialize the message components to bytes.
387 385
386 This is roughly the inverse of unserialize. The serialize/unserialize
387 methods work with full message lists, whereas pack/unpack work with
388 the individual message parts in the message list.
389
388 390 Parameters
389 391 ----------
390 392 msg : dict or Message
@@ -435,7 +437,7 b' class Session(Configurable):'
435 437 return to_send
436 438
437 439 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
438 buffers=None, subheader=None, track=False):
440 buffers=None, subheader=None, track=False, header=None):
439 441 """Build and send a message via stream or socket.
440 442
441 443 The message format used by this function internally is as follows:
@@ -443,37 +445,42 b' class Session(Configurable):'
443 445 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
444 446 buffer1,buffer2,...]
445 447
446 The self.serialize method converts the nested message dict into this
448 The serialize/unserialize methods convert the nested message dict into this
447 449 format.
448 450
449 451 Parameters
450 452 ----------
451 453
452 454 stream : zmq.Socket or ZMQStream
453 the socket-like object used to send the data
455 The socket-like object used to send the data.
454 456 msg_or_type : str or Message/dict
455 457 Normally, msg_or_type will be a msg_type unless a message is being
456 sent more than once.
458 sent more than once. If a header is supplied, this can be set to
459 None and the msg_type will be pulled from the header.
457 460
458 461 content : dict or None
459 the content of the message (ignored if msg_or_type is a message)
462 The content of the message (ignored if msg_or_type is a message).
463 header : dict or None
464 The header dict for the message (ignores if msg_to_type is a message).
460 465 parent : Message or dict or None
461 the parent or parent header describing the parent of this message
466 The parent or parent header describing the parent of this message
467 (ignored if msg_or_type is a message).
462 468 ident : bytes or list of bytes
463 the zmq.IDENTITY routing path
469 The zmq.IDENTITY routing path.
464 470 subheader : dict or None
465 extra header keys for this message's header
471 Extra header keys for this message's header (ignored if msg_or_type
472 is a message).
466 473 buffers : list or None
467 the already-serialized buffers to be appended to the message
474 The already-serialized buffers to be appended to the message.
468 475 track : bool
469 whether to track. Only for use with Sockets,
470 because ZMQStream objects cannot track messages.
476 Whether to track. Only for use with Sockets, because ZMQStream
477 objects cannot track messages.
471 478
472 479 Returns
473 480 -------
474 msg : message dict
475 the constructed message
476 (msg,tracker) : (message dict, MessageTracker)
481 msg : dict
482 The constructed message.
483 (msg,tracker) : (dict, MessageTracker)
477 484 if track=True, then a 2-tuple will be returned,
478 485 the first element being the constructed
479 486 message, and the second being the MessageTracker
@@ -486,11 +493,12 b' class Session(Configurable):'
486 493 raise TypeError("ZMQStream cannot track messages")
487 494
488 495 if isinstance(msg_or_type, (Message, dict)):
489 # we got a Message, not a msg_type
490 # don't build a new Message
496 # We got a Message or message dict, not a msg_type so don't
497 # build a new Message.
491 498 msg = msg_or_type
492 499 else:
493 msg = self.msg(msg_or_type, content, parent, subheader)
500 msg = self.msg(msg_or_type, content=content, parent=parent,
501 subheader=subheader, header=header)
494 502
495 503 buffers = [] if buffers is None else buffers
496 504 to_send = self.serialize(msg, ident)
@@ -578,7 +586,7 b' class Session(Configurable):'
578 586 # invalid large messages can cause very expensive string comparisons
579 587 idents, msg_list = self.feed_identities(msg_list, copy)
580 588 try:
581 return idents, self.unpack_message(msg_list, content=content, copy=copy)
589 return idents, self.unserialize(msg_list, content=content, copy=copy)
582 590 except Exception as e:
583 591 print (idents, msg_list)
584 592 # TODO: handle it
@@ -601,9 +609,11 b' class Session(Configurable):'
601 609 Returns
602 610 -------
603 611 (idents,msg_list) : two lists
604 idents will always be a list of bytes - the indentity prefix
605 msg_list will be a list of bytes or Messages, unchanged from input
606 msg_list should be unpackable via self.unpack_message at this point.
612 idents will always be a list of bytes, each of which is a ZMQ
613 identity. msg_list will be a list of bytes or zmq.Messages of the
614 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
615 should be unpackable/unserializable via self.unserialize at this
616 point.
607 617 """
608 618 if copy:
609 619 idx = msg_list.index(DELIM)
@@ -619,21 +629,30 b' class Session(Configurable):'
619 629 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
620 630 return [m.bytes for m in idents], msg_list
621 631
622 def unpack_message(self, msg_list, content=True, copy=True):
623 """Return a message object from the format
624 sent by self.send.
632 def unserialize(self, msg_list, content=True, copy=True):
633 """Unserialize a msg_list to a nested message dict.
634
635 This is roughly the inverse of serialize. The serialize/unserialize
636 methods work with full message lists, whereas pack/unpack work with
637 the individual message parts in the message list.
625 638
626 639 Parameters:
627 640 -----------
628
641 msg_list : list of bytes or Message objects
642 The list of message parts of the form [HMAC,p_header,p_parent,
643 p_content,buffer1,buffer2,...].
629 644 content : bool (True)
630 whether to unpack the content dict (True),
631 or leave it serialized (False)
632
645 Whether to unpack the content dict (True), or leave it packed
646 (False).
633 647 copy : bool (True)
634 whether to return the bytes (True),
635 or the non-copying Message object in each place (False)
648 Whether to return the bytes (True), or the non-copying Message
649 object in each place (False).
636 650
651 Returns
652 -------
653 msg : dict
654 The nested message dict with top-level keys [header, parent_header,
655 content, buffers].
637 656 """
638 657 minlen = 4
639 658 message = {}
@@ -651,7 +670,6 b' class Session(Configurable):'
651 670 if not len(msg_list) >= minlen:
652 671 raise TypeError("malformed message, must have at least %i elements"%minlen)
653 672 message['header'] = self.unpack(msg_list[1])
654 message['msg_type'] = message['header']['msg_type']
655 673 message['parent_header'] = self.unpack(msg_list[2])
656 674 if content:
657 675 message['content'] = self.unpack(msg_list[3])
@@ -26,37 +26,102 b' class SessionTestCase(BaseZMQTestCase):'
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29
30 class MockSocket(zmq.Socket):
31
32 def __init__(self, *args, **kwargs):
33 super(MockSocket,self).__init__(*args,**kwargs)
34 self.data = []
35
36 def send_multipart(self, msgparts, *args, **kwargs):
37 self.data.extend(msgparts)
38
39 def send(self, part, *args, **kwargs):
40 self.data.append(part)
41
42 def recv_multipart(self, *args, **kwargs):
43 return self.data
44
29 45 class TestSession(SessionTestCase):
30 46
31 47 def test_msg(self):
32 48 """message format"""
33 49 msg = self.session.msg('execute')
34 thekeys = set('header msg_id parent_header msg_type content'.split())
50 thekeys = set('header parent_header content'.split())
35 51 s = set(msg.keys())
36 52 self.assertEquals(s, thekeys)
37 53 self.assertTrue(isinstance(msg['content'],dict))
38 54 self.assertTrue(isinstance(msg['header'],dict))
39 55 self.assertTrue(isinstance(msg['parent_header'],dict))
40 self.assertEquals(msg['msg_type'], 'execute')
41
42
56 self.assertEquals(msg['header']['msg_type'], 'execute')
57
58 def test_serialize(self):
59 msg = self.session.msg('execute',content=dict(a=10))
60 msg_list = self.session.serialize(msg, ident=b'foo')
61 ident, msg_list = self.session.feed_identities(msg_list)
62 new_msg = self.session.unserialize(msg_list)
63 self.assertEquals(ident[0], b'foo')
64 self.assertEquals(new_msg['header'],msg['header'])
65 self.assertEquals(new_msg['content'],msg['content'])
66 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
67
68 def test_send(self):
69 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
70
71 msg = self.session.msg('execute', content=dict(a=10))
72 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
73 ident, msg_list = self.session.feed_identities(socket.data)
74 new_msg = self.session.unserialize(msg_list)
75 self.assertEquals(ident[0], b'foo')
76 self.assertEquals(new_msg['header'],msg['header'])
77 self.assertEquals(new_msg['content'],msg['content'])
78 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
79 self.assertEquals(new_msg['buffers'],[b'bar'])
80
81 socket.data = []
82
83 content = msg['content']
84 header = msg['header']
85 parent = msg['parent_header']
86 msg_type = header['msg_type']
87 self.session.send(socket, None, content=content, parent=parent,
88 header=header, ident=b'foo', buffers=[b'bar'])
89 ident, msg_list = self.session.feed_identities(socket.data)
90 new_msg = self.session.unserialize(msg_list)
91 self.assertEquals(ident[0], b'foo')
92 self.assertEquals(new_msg['header'],msg['header'])
93 self.assertEquals(new_msg['content'],msg['content'])
94 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
95 self.assertEquals(new_msg['buffers'],[b'bar'])
96
97 socket.data = []
98
99 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
100 ident, new_msg = self.session.recv(socket)
101 self.assertEquals(ident[0], b'foo')
102 self.assertEquals(new_msg['header'],msg['header'])
103 self.assertEquals(new_msg['content'],msg['content'])
104 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
105 self.assertEquals(new_msg['buffers'],[b'bar'])
106
107 socket.close()
43 108
44 109 def test_args(self):
45 110 """initialization arguments for Session"""
46 111 s = self.session
47 112 self.assertTrue(s.pack is ss.default_packer)
48 113 self.assertTrue(s.unpack is ss.default_unpacker)
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
114 self.assertEquals(s.username, os.environ.get('USER', u'username'))
50 115
51 116 s = ss.Session()
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
117 self.assertEquals(s.username, os.environ.get('USER', u'username'))
53 118
54 119 self.assertRaises(TypeError, ss.Session, pack='hi')
55 120 self.assertRaises(TypeError, ss.Session, unpack='hi')
56 121 u = str(uuid.uuid4())
57 s = ss.Session(username='carrot', session=u)
122 s = ss.Session(username=u'carrot', session=u)
58 123 self.assertEquals(s.session, u)
59 self.assertEquals(s.username, 'carrot')
124 self.assertEquals(s.username, u'carrot')
60 125
61 126 def test_tracking(self):
62 127 """test tracking messages"""
@@ -109,3 +174,4 b' class TestSession(SessionTestCase):'
109 174 content = dict(code='whoda',stuff=object())
110 175 themsg = self.session.msg('execute',content=content)
111 176 pmsg = theids
177
@@ -101,18 +101,18 b' generic structure::'
101 101 # collaborative settings where multiple users may be interacting with the
102 102 # same kernel simultaneously, so that frontends can label the various
103 103 # messages in a meaningful way.
104 'header' : { 'msg_id' : uuid,
104 'header' : {
105 'msg_id' : uuid,
105 106 'username' : str,
106 107 'session' : uuid
108 # All recognized message type strings are listed below.
109 'msg_type' : str,
107 110 },
108 111
109 112 # In a chain of messages, the header from the parent is copied so that
110 113 # clients can track where messages come from.
111 114 'parent_header' : dict,
112 115
113 # All recognized message type strings are listed below.
114 'msg_type' : str,
115
116 116 # The actual content of the message must be a dict, whose structure
117 117 # depends on the message type.x
118 118 'content' : dict,
General Comments 0
You need to be logged in to leave comments. Login now