From 8de5e1156f8f6eca69316990f76f6c0be9dd6be5 2011-08-10 04:14:56 From: Brian E. Granger Date: 2011-08-10 04:14:56 Subject: [PATCH] Merge branch 'sessionwork' --- diff --git a/IPython/frontend/qt/base_frontend_mixin.py b/IPython/frontend/qt/base_frontend_mixin.py index 02080b7..c4189d3 100644 --- a/IPython/frontend/qt/base_frontend_mixin.py +++ b/IPython/frontend/qt/base_frontend_mixin.py @@ -96,7 +96,7 @@ class BaseFrontendMixin(object): """ Calls the frontend handler associated with the message type of the given message. """ - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] handler = getattr(self, '_handle_' + msg_type, None) if handler: handler(msg) diff --git a/IPython/frontend/qt/kernelmanager.py b/IPython/frontend/qt/kernelmanager.py index c6a5dde..f243114 100644 --- a/IPython/frontend/qt/kernelmanager.py +++ b/IPython/frontend/qt/kernelmanager.py @@ -66,7 +66,7 @@ class QtShellSocketChannel(SocketChannelQObject, ShellSocketChannel): self.message_received.emit(msg) # Emit signals for specialized message types. - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] signal = getattr(self, msg_type, None) if signal: signal.emit(msg) @@ -122,7 +122,7 @@ class QtSubSocketChannel(SocketChannelQObject, SubSocketChannel): # Emit the generic signal. self.message_received.emit(msg) # Emit signals for specialized message types. - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] signal = getattr(self, msg_type + '_received', None) if signal: signal.emit(msg) @@ -155,7 +155,7 @@ class QtStdInSocketChannel(SocketChannelQObject, StdInSocketChannel): self.message_received.emit(msg) # Emit signals for specialized message types. 
- msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] if msg_type == 'input_request': self.input_requested.emit(msg) diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index edd7269..846b8eb 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -670,7 +670,7 @@ class Client(HasTraits): while msg is not None: if self.debug: pprint(msg) - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s"%msg.msg_type) @@ -684,7 +684,7 @@ class Client(HasTraits): while msg is not None: if self.debug: pprint(msg) - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] handler = self._queue_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s"%msg.msg_type) @@ -729,7 +729,7 @@ class Client(HasTraits): msg_id = parent['msg_id'] content = msg['content'] header = msg['header'] - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] # init metadata: md = self.metadata[msg_id] @@ -994,7 +994,7 @@ class Client(HasTraits): msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident, subheader=subheader, track=track) - msg_id = msg['msg_id'] + msg_id = msg['header']['msg_id'] self.outstanding.add(msg_id) if ident: # possibly routed to a specific engine diff --git a/IPython/parallel/client/view.py b/IPython/parallel/client/view.py index 12ed96d..1b1eae6 100644 --- a/IPython/parallel/client/view.py +++ b/IPython/parallel/client/view.py @@ -523,7 +523,7 @@ class DirectView(View): ident=ident) if track: trackers.append(msg['tracker']) - msg_ids.append(msg['msg_id']) + msg_ids.append(msg['header']['msg_id']) tracker = None if track is False else zmq.MessageTracker(*trackers) ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) if block: @@ -980,7 +980,7 @@ class 
LoadBalancedView(View): subheader=subheader) tracker = None if track is False else msg['tracker'] - ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker) + ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker) if block: try: diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index 5a66178..326213a 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -485,7 +485,7 @@ class Hub(SessionFactory): return client_id = idents[0] try: - msg = self.session.unpack_message(msg, content=True) + msg = self.session.unserialize(msg, content=True) except Exception: content = error.wrap_exception() self.log.error("Bad Query Message: %r"%msg, exc_info=True) @@ -494,7 +494,7 @@ class Hub(SessionFactory): return # print client_id, header, parent, content #switch on message type: - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] self.log.info("client::client %r requested %r"%(client_id, msg_type)) handler = self.query_handlers.get(msg_type, None) try: @@ -550,7 +550,7 @@ class Hub(SessionFactory): return queue_id, client_id = idents[:2] try: - msg = self.session.unpack_message(msg) + msg = self.session.unserialize(msg) except Exception: self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True) return @@ -597,7 +597,7 @@ class Hub(SessionFactory): client_id, queue_id = idents[:2] try: - msg = self.session.unpack_message(msg) + msg = self.session.unserialize(msg) except Exception: self.log.error("queue::engine %r sent invalid message to %r: %r"%( queue_id,client_id, msg), exc_info=True) @@ -647,7 +647,7 @@ class Hub(SessionFactory): client_id = idents[0] try: - msg = self.session.unpack_message(msg) + msg = self.session.unserialize(msg) except Exception: self.log.error("task::client %r sent invalid task message: %r"%( client_id, msg), exc_info=True) @@ -697,7 
+697,7 @@ class Hub(SessionFactory): """save the result of a completed task.""" client_id = idents[0] try: - msg = self.session.unpack_message(msg) + msg = self.session.unserialize(msg) except Exception: self.log.error("task::invalid task result message send to %r: %r"%( client_id, msg), exc_info=True) @@ -744,7 +744,7 @@ class Hub(SessionFactory): def save_task_destination(self, idents, msg): try: - msg = self.session.unpack_message(msg, content=True) + msg = self.session.unserialize(msg, content=True) except Exception: self.log.error("task::invalid task tracking message", exc_info=True) return @@ -781,7 +781,7 @@ class Hub(SessionFactory): """save an iopub message into the db""" # print (topics) try: - msg = self.session.unpack_message(msg, content=True) + msg = self.session.unserialize(msg, content=True) except Exception: self.log.error("iopub::invalid IOPub message", exc_info=True) return @@ -791,7 +791,7 @@ class Hub(SessionFactory): self.log.error("iopub::invalid IOPub message: %r"%msg) return msg_id = parent['msg_id'] - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] content = msg['content'] # ensure msg_id is in db @@ -1165,7 +1165,7 @@ class Hub(SessionFactory): msg = self.session.msg(header['msg_type']) msg['content'] = rec['content'] msg['header'] = header - msg['msg_id'] = rec['msg_id'] + msg['header']['msg_id'] = rec['msg_id'] self.session.send(self.resubmit, msg, buffers=rec['buffers']) finish(dict(status='ok')) diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py index 747d5b6..d7e6da6 100644 --- a/IPython/parallel/controller/scheduler.py +++ b/IPython/parallel/controller/scheduler.py @@ -211,12 +211,12 @@ class TaskScheduler(SessionFactory): self.log.warn("task::Invalid Message: %r",msg) return try: - msg = self.session.unpack_message(msg) + msg = self.session.unserialize(msg) except ValueError: self.log.warn("task::Unauthorized message from: %r"%idents) return - msg_type = 
msg['msg_type'] + msg_type = msg['header']['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: @@ -307,7 +307,7 @@ class TaskScheduler(SessionFactory): self.notifier_stream.flush() try: idents, msg = self.session.feed_identities(raw_msg, copy=False) - msg = self.session.unpack_message(msg, content=False, copy=False) + msg = self.session.unserialize(msg, content=False, copy=False) except Exception: self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True) return @@ -515,7 +515,7 @@ class TaskScheduler(SessionFactory): """dispatch method for result replies""" try: idents,msg = self.session.feed_identities(raw_msg, copy=False) - msg = self.session.unpack_message(msg, content=False, copy=False) + msg = self.session.unserialize(msg, content=False, copy=False) engine = idents[0] try: idx = self.targets.index(engine) diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py index dd91f8c..04201e9 100755 --- a/IPython/parallel/engine/engine.py +++ b/IPython/parallel/engine/engine.py @@ -90,7 +90,7 @@ class EngineFactory(RegistrationFactory): loop = self.loop identity = self.bident idents,msg = self.session.feed_identities(msg) - msg = Message(self.session.unpack_message(msg)) + msg = Message(self.session.unserialize(msg)) if msg.content.status == 'ok': self.id = int(msg.content.id) diff --git a/IPython/parallel/engine/kernelstarter.py b/IPython/parallel/engine/kernelstarter.py index fe92d1d..4aa0238 100644 --- a/IPython/parallel/engine/kernelstarter.py +++ b/IPython/parallel/engine/kernelstarter.py @@ -40,11 +40,11 @@ class KernelStarter(object): def dispatch_request(self, raw_msg): idents, msg = self.session.feed_identities() try: - msg = self.session.unpack_message(msg, content=False) + msg = self.session.unserialize(msg, content=False) except: print ("bad msg: %s"%msg) - msgtype = msg['msg_type'] + msgtype = msg['header']['msg_type'] handler = self.handlers.get(msgtype, None) if handler is None: 
self.downstream.send_multipart(raw_msg, copy=False) @@ -54,11 +54,11 @@ class KernelStarter(object): def dispatch_reply(self, raw_msg): idents, msg = self.session.feed_identities() try: - msg = self.session.unpack_message(msg, content=False) + msg = self.session.unserialize(msg, content=False) except: print ("bad msg: %s"%msg) - msgtype = msg['msg_type'] + msgtype = msg['header']['msg_type'] handler = self.handlers.get(msgtype, None) if handler is None: self.upstream.send_multipart(raw_msg, copy=False) @@ -227,4 +227,4 @@ def make_starter(up_addr, down_addr, *args, **kwargs): starter = KernelStarter(session, upstream, downstream, *args, **kwargs) starter.start() loop.start() - \ No newline at end of file + diff --git a/IPython/parallel/engine/streamkernel.py b/IPython/parallel/engine/streamkernel.py index 5e6203b..a660985 100755 --- a/IPython/parallel/engine/streamkernel.py +++ b/IPython/parallel/engine/streamkernel.py @@ -150,7 +150,7 @@ class Kernel(SessionFactory): self.log.info("Aborting:") self.log.info(str(msg)) - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg) # self.reply_socket.send(ident,zmq.SNDMORE) @@ -195,7 +195,7 @@ class Kernel(SessionFactory): def dispatch_control(self, msg): idents,msg = self.session.feed_identities(msg, copy=False) try: - msg = self.session.unpack_message(msg, content=True, copy=False) + msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return @@ -204,10 +204,11 @@ class Kernel(SessionFactory): header = msg['header'] msg_id = header['msg_id'] - - handler = self.control_handlers.get(msg['msg_type'], None) + msg_type = header['msg_type'] + + handler = self.control_handlers.get(msg_type, None) if handler is None: - self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type']) + self.log.error("UNKNOWN CONTROL MESSAGE TYPE: 
%r"%msg_type) else: handler(self.control_stream, idents, msg) @@ -373,7 +374,7 @@ class Kernel(SessionFactory): self.control_stream.flush() idents,msg = self.session.feed_identities(msg, copy=False) try: - msg = self.session.unpack_message(msg, content=True, copy=False) + msg = self.session.unserialize(msg, content=True, copy=False) except: self.log.error("Invalid Message", exc_info=True) return @@ -383,17 +384,18 @@ class Kernel(SessionFactory): header = msg['header'] msg_id = header['msg_id'] + msg_type = msg['header']['msg_type'] if self.check_aborted(msg_id): self.aborted.remove(msg_id) # is it safe to assume a msg_id will not be resubmitted? - reply_type = msg['msg_type'].split('_')[0] + '_reply' + reply_type = msg_type.split('_')[0] + '_reply' status = {'status' : 'aborted'} reply_msg = self.session.send(stream, reply_type, subheader=status, content=status, parent=msg, ident=idents) return - handler = self.shell_handlers.get(msg['msg_type'], None) + handler = self.shell_handlers.get(msg_type, None) if handler is None: - self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type']) + self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg_type) else: handler(stream, idents, msg) diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py index fddb961..4711d9a 100644 --- a/IPython/parallel/tests/test_db.py +++ b/IPython/parallel/tests/test_db.py @@ -56,8 +56,9 @@ class TestDictBackend(TestCase): msg = self.session.msg('apply_request', content=dict(a=5)) msg['buffers'] = [] rec = init_record(msg) - msg_ids.append(msg['msg_id']) - self.db.add_record(msg['msg_id'], rec) + msg_id = msg['header']['msg_id'] + msg_ids.append(msg_id) + self.db.add_record(msg_id, rec) return msg_ids def test_add_record(self): diff --git a/IPython/zmq/ipkernel.py b/IPython/zmq/ipkernel.py index 3b47f6a..ceea88d 100755 --- a/IPython/zmq/ipkernel.py +++ b/IPython/zmq/ipkernel.py @@ -124,7 +124,9 @@ class Kernel(Configurable): ident,msg = self.session.recv(self.shell_socket, 
zmq.NOBLOCK) if msg is None: return - + + msg_type = msg['header']['msg_type'] + # This assert will raise in versions of zeromq 2.0.7 and lesser. # We now require 2.0.8 or above, so we can uncomment for safety. # print(ident,msg, file=sys.__stdout__) @@ -133,11 +135,11 @@ class Kernel(Configurable): # Print some info about this message and leave a '--->' marker, so it's # easier to trace visually the message chain when debugging. Each # handler prints its message at the end. - self.log.debug('\n*** MESSAGE TYPE:'+str(msg['msg_type'])+'***') + self.log.debug('\n*** MESSAGE TYPE:'+str(msg_type)+'***') self.log.debug(' Content: '+str(msg['content'])+'\n --->\n ') # Find and call actual handler for message - handler = self.handlers.get(msg['msg_type'], None) + handler = self.handlers.get(msg_type, None) if handler is None: self.log.error("UNKNOWN MESSAGE TYPE:" +str(msg)) else: @@ -375,7 +377,7 @@ class Kernel(Configurable): "Unexpected missing message part." self.log.debug("Aborting:\n"+str(Message(msg))) - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' reply_msg = self.session.send(self.shell_socket, reply_type, {'status' : 'aborted'}, msg, ident=ident) diff --git a/IPython/zmq/pykernel.py b/IPython/zmq/pykernel.py index fb62629..42835da 100755 --- a/IPython/zmq/pykernel.py +++ b/IPython/zmq/pykernel.py @@ -190,7 +190,7 @@ class Kernel(HasTraits): else: assert ident is not None, "Missing message part." 
self.log.debug("Aborting: %s"%Message(msg)) - msg_type = msg['msg_type'] + msg_type = msg['header']['msg_type'] reply_type = msg_type.split('_')[0] + '_reply' reply_msg = self.session.send(self.shell_socket, reply_type, {'status':'aborted'}, msg, ident=ident) self.log.debug(Message(reply_msg)) diff --git a/IPython/zmq/session.py b/IPython/zmq/session.py index 7d2ebc9..98a573a 100644 --- a/IPython/zmq/session.py +++ b/IPython/zmq/session.py @@ -244,7 +244,7 @@ class Session(Configurable): def _session_default(self): return bytes(uuid.uuid4()) - username = Unicode(os.environ.get('USER','username'), config=True, + username = Unicode(os.environ.get('USER',u'username'), config=True, help="""Username for the Session. Default is your system username.""") # message signature related traits: @@ -350,18 +350,16 @@ class Session(Configurable): def msg_header(self, msg_type): return msg_header(self.msg_id, msg_type, self.username, self.session) - def msg(self, msg_type, content=None, parent=None, subheader=None): + def msg(self, msg_type, content=None, parent=None, subheader=None, header=None): """Return the nested message dict. This format is different from what is sent over the wire. The - self.serialize method converts this nested message dict to the wire - format, which uses a message list. + serialize/unserialize methods convert this nested message dict to the wire + format, which is a list of message parts. """ msg = {} - msg['header'] = self.msg_header(msg_type) - msg['msg_id'] = msg['header']['msg_id'] + msg['header'] = self.msg_header(msg_type) if header is None else header msg['parent_header'] = {} if parent is None else extract_header(parent) - msg['msg_type'] = msg_type msg['content'] = {} if content is None else content sub = {} if subheader is None else subheader msg['header'].update(sub) @@ -385,6 +383,10 @@ class Session(Configurable): def serialize(self, msg, ident=None): """Serialize the message components to bytes. 
+ This is roughly the inverse of unserialize. The serialize/unserialize + methods work with full message lists, whereas pack/unpack work with + the individual message parts in the message list. + Parameters ---------- msg : dict or Message @@ -434,8 +436,8 @@ class Session(Configurable): return to_send - def send(self, stream, msg_or_type, content=None, parent=None, ident=None, - buffers=None, subheader=None, track=False): + def send(self, stream, msg_or_type, content=None, parent=None, ident=None, + buffers=None, subheader=None, track=False, header=None): """Build and send a message via stream or socket. The message format used by this function internally is as follows: @@ -443,37 +445,42 @@ class Session(Configurable): [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, buffer1,buffer2,...] - The self.serialize method converts the nested message dict into this + The serialize/unserialize methods convert the nested message dict into this format. Parameters ---------- stream : zmq.Socket or ZMQStream - the socket-like object used to send the data + The socket-like object used to send the data. msg_or_type : str or Message/dict Normally, msg_or_type will be a msg_type unless a message is being - sent more than once. + sent more than once. If a header is supplied, this can be set to + None and the msg_type will be pulled from the header. content : dict or None - the content of the message (ignored if msg_or_type is a message) + The content of the message (ignored if msg_or_type is a message). + header : dict or None + The header dict for the message (ignored if msg_or_type is a message). parent : Message or dict or None - the parent or parent header describing the parent of this message + The parent or parent header describing the parent of this message + (ignored if msg_or_type is a message). ident : bytes or list of bytes - the zmq.IDENTITY routing path + The zmq.IDENTITY routing path. 
subheader : dict or None - extra header keys for this message's header + Extra header keys for this message's header (ignored if msg_or_type + is a message). buffers : list or None - the already-serialized buffers to be appended to the message + The already-serialized buffers to be appended to the message. track : bool - whether to track. Only for use with Sockets, - because ZMQStream objects cannot track messages. + Whether to track. Only for use with Sockets, because ZMQStream + objects cannot track messages. Returns ------- - msg : message dict - the constructed message - (msg,tracker) : (message dict, MessageTracker) + msg : dict + The constructed message. + (msg,tracker) : (dict, MessageTracker) if track=True, then a 2-tuple will be returned, the first element being the constructed message, and the second being the MessageTracker @@ -486,12 +493,13 @@ class Session(Configurable): raise TypeError("ZMQStream cannot track messages") if isinstance(msg_or_type, (Message, dict)): - # we got a Message, not a msg_type - # don't build a new Message + # We got a Message or message dict, not a msg_type so don't + # build a new Message. msg = msg_or_type else: - msg = self.msg(msg_or_type, content, parent, subheader) - + msg = self.msg(msg_or_type, content=content, parent=parent, + subheader=subheader, header=header) + buffers = [] if buffers is None else buffers to_send = self.serialize(msg, ident) flag = 0 @@ -521,7 +529,7 @@ class Session(Configurable): msg['tracker'] = tracker return msg - + def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): """Send a raw message via ident path. 
@@ -543,7 +551,7 @@ class Session(Configurable): ident = [ident] if ident is not None: to_send.extend(ident) - + to_send.append(DELIM) to_send.append(self.sign(msg_list)) to_send.extend(msg_list) @@ -578,7 +586,7 @@ class Session(Configurable): # invalid large messages can cause very expensive string comparisons idents, msg_list = self.feed_identities(msg_list, copy) try: - return idents, self.unpack_message(msg_list, content=content, copy=copy) + return idents, self.unserialize(msg_list, content=content, copy=copy) except Exception as e: print (idents, msg_list) # TODO: handle it @@ -600,10 +608,12 @@ class Session(Configurable): Returns ------- - (idents,msg_list) : two lists - idents will always be a list of bytes - the indentity prefix - msg_list will be a list of bytes or Messages, unchanged from input - msg_list should be unpackable via self.unpack_message at this point. + (idents, msg_list) : two lists + idents will always be a list of bytes, each of which is a ZMQ + identity. msg_list will be a list of bytes or zmq.Messages of the + form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and + should be unpackable/unserializable via self.unserialize at this + point. """ if copy: idx = msg_list.index(DELIM) @@ -619,21 +629,30 @@ class Session(Configurable): idents, msg_list = msg_list[:idx], msg_list[idx+1:] return [m.bytes for m in idents], msg_list - def unpack_message(self, msg_list, content=True, copy=True): - """Return a message object from the format - sent by self.send. - + def unserialize(self, msg_list, content=True, copy=True): + """Unserialize a msg_list to a nested message dict. + + This is roughly the inverse of serialize. The serialize/unserialize + methods work with full message lists, whereas pack/unpack work with + the individual message parts in the message list. + Parameters: ----------- - + msg_list : list of bytes or Message objects + The list of message parts of the form [HMAC,p_header,p_parent, + p_content,buffer1,buffer2,...]. 
content : bool (True) - whether to unpack the content dict (True), - or leave it serialized (False) - + Whether to unpack the content dict (True), or leave it packed + (False). copy : bool (True) - whether to return the bytes (True), - or the non-copying Message object in each place (False) - + Whether to return the bytes (True), or the non-copying Message + object in each place (False). + + Returns + ------- + msg : dict + The nested message dict with top-level keys [header, parent_header, + content, buffers]. """ minlen = 4 message = {} @@ -651,7 +670,6 @@ class Session(Configurable): if not len(msg_list) >= minlen: raise TypeError("malformed message, must have at least %i elements"%minlen) message['header'] = self.unpack(msg_list[1]) - message['msg_type'] = message['header']['msg_type'] message['parent_header'] = self.unpack(msg_list[2]) if content: message['content'] = self.unpack(msg_list[3]) diff --git a/IPython/zmq/tests/test_session.py b/IPython/zmq/tests/test_session.py index 6279acc..7c4cf9d 100644 --- a/IPython/zmq/tests/test_session.py +++ b/IPython/zmq/tests/test_session.py @@ -26,37 +26,102 @@ class SessionTestCase(BaseZMQTestCase): BaseZMQTestCase.setUp(self) self.session = ss.Session() + +class MockSocket(zmq.Socket): + + def __init__(self, *args, **kwargs): + super(MockSocket,self).__init__(*args,**kwargs) + self.data = [] + + def send_multipart(self, msgparts, *args, **kwargs): + self.data.extend(msgparts) + + def send(self, part, *args, **kwargs): + self.data.append(part) + + def recv_multipart(self, *args, **kwargs): + return self.data + class TestSession(SessionTestCase): def test_msg(self): """message format""" msg = self.session.msg('execute') - thekeys = set('header msg_id parent_header msg_type content'.split()) + thekeys = set('header parent_header content'.split()) s = set(msg.keys()) self.assertEquals(s, thekeys) self.assertTrue(isinstance(msg['content'],dict)) self.assertTrue(isinstance(msg['header'],dict)) 
self.assertTrue(isinstance(msg['parent_header'],dict)) - self.assertEquals(msg['msg_type'], 'execute') - - - + self.assertEquals(msg['header']['msg_type'], 'execute') + + def test_serialize(self): + msg = self.session.msg('execute',content=dict(a=10)) + msg_list = self.session.serialize(msg, ident=b'foo') + ident, msg_list = self.session.feed_identities(msg_list) + new_msg = self.session.unserialize(msg_list) + self.assertEquals(ident[0], b'foo') + self.assertEquals(new_msg['header'],msg['header']) + self.assertEquals(new_msg['content'],msg['content']) + self.assertEquals(new_msg['parent_header'],msg['parent_header']) + + def test_send(self): + socket = MockSocket(zmq.Context.instance(),zmq.PAIR) + + msg = self.session.msg('execute', content=dict(a=10)) + self.session.send(socket, msg, ident=b'foo', buffers=[b'bar']) + ident, msg_list = self.session.feed_identities(socket.data) + new_msg = self.session.unserialize(msg_list) + self.assertEquals(ident[0], b'foo') + self.assertEquals(new_msg['header'],msg['header']) + self.assertEquals(new_msg['content'],msg['content']) + self.assertEquals(new_msg['parent_header'],msg['parent_header']) + self.assertEquals(new_msg['buffers'],[b'bar']) + + socket.data = [] + + content = msg['content'] + header = msg['header'] + parent = msg['parent_header'] + msg_type = header['msg_type'] + self.session.send(socket, None, content=content, parent=parent, + header=header, ident=b'foo', buffers=[b'bar']) + ident, msg_list = self.session.feed_identities(socket.data) + new_msg = self.session.unserialize(msg_list) + self.assertEquals(ident[0], b'foo') + self.assertEquals(new_msg['header'],msg['header']) + self.assertEquals(new_msg['content'],msg['content']) + self.assertEquals(new_msg['parent_header'],msg['parent_header']) + self.assertEquals(new_msg['buffers'],[b'bar']) + + socket.data = [] + + self.session.send(socket, msg, ident=b'foo', buffers=[b'bar']) + ident, new_msg = self.session.recv(socket) + self.assertEquals(ident[0], b'foo') + 
self.assertEquals(new_msg['header'],msg['header']) + self.assertEquals(new_msg['content'],msg['content']) + self.assertEquals(new_msg['parent_header'],msg['parent_header']) + self.assertEquals(new_msg['buffers'],[b'bar']) + + socket.close() + def test_args(self): """initialization arguments for Session""" s = self.session self.assertTrue(s.pack is ss.default_packer) self.assertTrue(s.unpack is ss.default_unpacker) - self.assertEquals(s.username, os.environ.get('USER', 'username')) + self.assertEquals(s.username, os.environ.get('USER', u'username')) s = ss.Session() - self.assertEquals(s.username, os.environ.get('USER', 'username')) + self.assertEquals(s.username, os.environ.get('USER', u'username')) self.assertRaises(TypeError, ss.Session, pack='hi') self.assertRaises(TypeError, ss.Session, unpack='hi') u = str(uuid.uuid4()) - s = ss.Session(username='carrot', session=u) + s = ss.Session(username=u'carrot', session=u) self.assertEquals(s.session, u) - self.assertEquals(s.username, 'carrot') + self.assertEquals(s.username, u'carrot') def test_tracking(self): """test tracking messages""" @@ -109,3 +174,4 @@ class TestSession(SessionTestCase): content = dict(code='whoda',stuff=object()) themsg = self.session.msg('execute',content=content) pmsg = theids + diff --git a/docs/source/development/messaging.txt b/docs/source/development/messaging.txt index 877cabf..f51a189 100644 --- a/docs/source/development/messaging.txt +++ b/docs/source/development/messaging.txt @@ -101,18 +101,18 @@ generic structure:: # collaborative settings where multiple users may be interacting with the # same kernel simultaneously, so that frontends can label the various # messages in a meaningful way. - 'header' : { 'msg_id' : uuid, - 'username' : str, - 'session' : uuid + 'header' : { + 'msg_id' : uuid, + 'username' : str, + 'session' : uuid, + # All recognized message type strings are listed below. 
+ 'msg_type' : str, }, # In a chain of messages, the header from the parent is copied so that # clients can track where messages come from. 'parent_header' : dict, - # All recognized message type strings are listed below. - 'msg_type' : str, - # The actual content of the message must be a dict, whose structure # depends on the message type.x 'content' : dict,