diff --git a/IPython/parallel/apps/ipcontrollerapp.py b/IPython/parallel/apps/ipcontrollerapp.py index 81a7c2d..c0f66c2 100755 --- a/IPython/parallel/apps/ipcontrollerapp.py +++ b/IPython/parallel/apps/ipcontrollerapp.py @@ -41,7 +41,7 @@ from IPython.utils.importstring import import_item from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict # from IPython.parallel.controller.controller import ControllerFactory -from IPython.parallel.streamsession import StreamSession +from IPython.zmq.session import Session from IPython.parallel.controller.heartmonitor import HeartMonitor from IPython.parallel.controller.hub import HubFactory from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler @@ -109,7 +109,7 @@ class IPControllerApp(BaseParallelApplication): name = u'ipcontroller' description = _description config_file_name = Unicode(default_config_file_name) - classes = [ProfileDir, StreamSession, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo + classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo # change default to True auto_create = Bool(True, config=True, @@ -155,9 +155,9 @@ class IPControllerApp(BaseParallelApplication): import_statements = 'IPControllerApp.import_statements', location = 'IPControllerApp.location', - ident = 'StreamSession.session', - user = 'StreamSession.username', - exec_key = 'StreamSession.keyfile', + ident = 'Session.session', + user = 'Session.username', + exec_key = 'Session.keyfile', url = 'HubFactory.url', ip = 'HubFactory.ip', @@ -201,7 +201,7 @@ class IPControllerApp(BaseParallelApplication): # load from engine config with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f: cfg = json.loads(f.read()) - key = c.StreamSession.key = cfg['exec_key'] + key = c.Session.key = cfg['exec_key'] xport,addr = cfg['url'].split('://') c.HubFactory.engine_transport = xport ip,ports = addr.split(':') @@ -239,9 +239,9 @@ 
class IPControllerApp(BaseParallelApplication): # with open(keyfile, 'w') as f: # f.write(key) # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) - c.StreamSession.key = key + c.Session.key = key else: - key = c.StreamSession.key = '' + key = c.Session.key = '' try: self.factory = HubFactory(config=c, log=self.log) diff --git a/IPython/parallel/apps/ipengineapp.py b/IPython/parallel/apps/ipengineapp.py index 241cad5..3130eb4 100755 --- a/IPython/parallel/apps/ipengineapp.py +++ b/IPython/parallel/apps/ipengineapp.py @@ -27,7 +27,7 @@ from IPython.parallel.apps.baseapp import BaseParallelApplication from IPython.zmq.log import EnginePUBHandler from IPython.config.configurable import Configurable -from IPython.parallel.streamsession import StreamSession +from IPython.zmq.session import Session from IPython.parallel.engine.engine import EngineFactory from IPython.parallel.engine.streamkernel import Kernel from IPython.parallel.util import disambiguate_url @@ -100,7 +100,7 @@ class IPEngineApp(BaseParallelApplication): app_name = Unicode(u'ipengine') description = Unicode(_description) config_file_name = Unicode(default_config_file_name) - classes = List([ProfileDir, StreamSession, EngineFactory, Kernel, MPI]) + classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI]) startup_script = Unicode(u'', config=True, help='specify a script to be run at startup') @@ -124,9 +124,9 @@ class IPEngineApp(BaseParallelApplication): c = 'IPEngineApp.startup_command', s = 'IPEngineApp.startup_script', - ident = 'StreamSession.session', - user = 'StreamSession.username', - exec_key = 'StreamSession.keyfile', + ident = 'Session.session', + user = 'Session.username', + exec_key = 'Session.keyfile', url = 'EngineFactory.url', ip = 'EngineFactory.ip', @@ -190,7 +190,7 @@ class IPEngineApp(BaseParallelApplication): if isinstance(v, unicode): d[k] = v.encode() if d['exec_key']: - config.StreamSession.key = d['exec_key'] + config.Session.key = d['exec_key'] d['url'] = 
disambiguate_url(d['url'], d['location']) config.EngineFactory.url = d['url'] config.EngineFactory.location = d['location'] diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index ee875d6..1c1eab6 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -23,6 +23,7 @@ pjoin = os.path.join import zmq # from zmq.eventloop import ioloop, zmqstream +from IPython.utils.jsonutil import extract_dates from IPython.utils.path import get_ipython_dir from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode, Dict, List, Bool, Set) @@ -30,9 +31,10 @@ from IPython.external.decorator import decorator from IPython.external.ssh import tunnel from IPython.parallel import error -from IPython.parallel import streamsession as ss from IPython.parallel import util +from IPython.zmq.session import Session, Message + from .asyncresult import AsyncResult, AsyncHubResult from IPython.core.newapplication import ProfileDir, ProfileDirError from .view import DirectView, LoadBalancedView @@ -294,9 +296,9 @@ class Client(HasTraits): arg = 'key' key_arg = {arg:exec_key} if username is None: - self.session = ss.StreamSession(**key_arg) + self.session = Session(**key_arg) else: - self.session = ss.StreamSession(username=username, **key_arg) + self.session = Session(username=username, **key_arg) self._query_socket = self._context.socket(zmq.XREQ) self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) if self._ssh: @@ -416,7 +418,7 @@ class Client(HasTraits): idents,msg = self.session.recv(self._query_socket,mode=0) if self.debug: pprint(msg) - msg = ss.Message(msg) + msg = Message(msg) content = msg.content self._config['registration'] = dict(content) if content.status == 'ok': @@ -478,11 +480,11 @@ class Client(HasTraits): md['engine_id'] = self._engines.get(md['engine_uuid'], None) if 'date' in parent: - md['submitted'] = datetime.strptime(parent['date'], util.ISO8601) + md['submitted'] = parent['date'] if 
'started' in header: - md['started'] = datetime.strptime(header['started'], util.ISO8601) + md['started'] = header['started'] if 'date' in header: - md['completed'] = datetime.strptime(header['date'], util.ISO8601) + md['completed'] = header['date'] return md def _register_engine(self, msg): @@ -528,7 +530,7 @@ class Client(HasTraits): header = {} parent['msg_id'] = msg_id header['engine'] = uuid - header['date'] = datetime.now().strftime(util.ISO8601) + header['date'] = datetime.now() msg = dict(parent_header=parent, header=header, content=content) self._handle_apply_reply(msg) @@ -551,7 +553,7 @@ class Client(HasTraits): def _handle_apply_reply(self, msg): """Save the reply to an apply_request into our results.""" - parent = msg['parent_header'] + parent = extract_dates(msg['parent_header']) msg_id = parent['msg_id'] if msg_id not in self.outstanding: if msg_id in self.history: @@ -563,7 +565,7 @@ class Client(HasTraits): else: self.outstanding.remove(msg_id) content = msg['content'] - header = msg['header'] + header = extract_dates(msg['header']) # construct metadata: md = self.metadata[msg_id] @@ -589,33 +591,31 @@ class Client(HasTraits): def _flush_notifications(self): """Flush notifications of engine registrations waiting in ZMQ queue.""" - msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) - msg = msg[-1] msg_type = msg['msg_type'] handler = self._notification_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s"%msg.msg_type) else: handler(msg) - msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK) def _flush_results(self, sock): """Flush task or queue results waiting in ZMQ queue.""" - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, 
mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) - msg = msg[-1] msg_type = msg['msg_type'] handler = self._queue_handlers.get(msg_type, None) if handler is None: raise Exception("Unhandled message type: %s"%msg.msg_type) else: handler(msg) - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK) def _flush_control(self, sock): """Flush replies from the control channel waiting @@ -624,12 +624,12 @@ class Client(HasTraits): Currently: ignore them.""" if self._ignored_control_replies <= 0: return - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: self._ignored_control_replies -= 1 if self.debug: pprint(msg) - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK) def _flush_ignored_control(self): """flush ignored control replies""" @@ -638,19 +638,18 @@ class Client(HasTraits): self._ignored_control_replies -= 1 def _flush_ignored_hub_replies(self): - msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) + ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) while msg is not None: - msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) + ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK) def _flush_iopub(self, sock): """Flush replies from the iopub channel waiting in the ZMQ queue. """ - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: if self.debug: pprint(msg) - msg = msg[-1] parent = msg['parent_header'] msg_id = parent['msg_id'] content = msg['content'] @@ -674,7 +673,7 @@ class Client(HasTraits): # reduntant? 
self.metadata[msg_id] = md - msg = self.session.recv(sock, mode=zmq.NOBLOCK) + idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK) #-------------------------------------------------------------------------- # len, getitem @@ -1172,6 +1171,7 @@ class Client(HasTraits): failures = [] # load cached results into result: content.update(local_results) + content = extract_dates(content) # update cache with results: for msg_id in sorted(theids): if msg_id in content['completed']: @@ -1338,6 +1338,8 @@ class Client(HasTraits): has_bufs = buffer_lens is not None has_rbufs = result_buffer_lens is not None for i,rec in enumerate(records): + # unpack timestamps + rec = extract_dates(rec) # relink buffers if has_bufs: blen = buffer_lens[i] @@ -1345,11 +1347,6 @@ class Client(HasTraits): if has_rbufs: blen = result_buffer_lens[i] rec['result_buffers'], buffers = buffers[:blen],buffers[blen:] - # turn timestamps back into times - for key in 'submitted started completed resubmitted'.split(): - maybedate = rec.get(key, None) - if maybedate and util.ISO8601_RE.match(maybedate): - rec[key] = datetime.strptime(maybedate, util.ISO8601) return records diff --git a/IPython/parallel/controller/hub.py b/IPython/parallel/controller/hub.py index ea30b8d..d2749a1 100755 --- a/IPython/parallel/controller/hub.py +++ b/IPython/parallel/controller/hub.py @@ -28,6 +28,7 @@ from IPython.utils.importstring import import_item from IPython.utils.traitlets import ( HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr ) +from IPython.utils.jsonutil import ISO8601, extract_dates from IPython.parallel import error, util from IPython.parallel.factory import RegistrationFactory, LoggingFactory @@ -71,13 +72,13 @@ def empty_record(): def init_record(msg): """Initialize a TaskRecord based on a request.""" - header = msg['header'] + header = extract_dates(msg['header']) return { 'msg_id' : header['msg_id'], 'header' : header, 'content': msg['content'], 'buffers': msg['buffers'], - 'submitted': 
datetime.strptime(header['date'], util.ISO8601), + 'submitted': header['date'], 'client_uuid' : None, 'engine_uuid' : None, 'started': None, @@ -295,7 +296,7 @@ class Hub(LoggingFactory): Parameters ========== loop: zmq IOLoop instance - session: StreamSession object + session: Session object context: zmq context for creating new connections (?) queue: ZMQStream for monitoring the command queue (SUB) query: ZMQStream for engine registration and client queries requests (XREP) @@ -610,11 +611,9 @@ class Hub(LoggingFactory): self.log.warn("queue:: unknown msg finished %r"%msg_id) return # update record anyway, because the unregistration could have been premature - rheader = msg['header'] - completed = datetime.strptime(rheader['date'], util.ISO8601) + rheader = extract_dates(msg['header']) + completed = rheader['date'] started = rheader.get('started', None) - if started is not None: - started = datetime.strptime(started, util.ISO8601) result = { 'result_header' : rheader, 'result_content': msg['content'], @@ -695,7 +694,7 @@ class Hub(LoggingFactory): if msg_id in self.unassigned: self.unassigned.remove(msg_id) - header = msg['header'] + header = extract_dates(msg['header']) engine_uuid = header.get('engine', None) eid = self.by_ident.get(engine_uuid, None) @@ -706,10 +705,8 @@ class Hub(LoggingFactory): self.completed[eid].append(msg_id) if msg_id in self.tasks[eid]: self.tasks[eid].remove(msg_id) - completed = datetime.strptime(header['date'], util.ISO8601) + completed = header['date'] started = header.get('started', None) - if started is not None: - started = datetime.strptime(started, util.ISO8601) result = { 'result_header' : header, 'result_content': msg['content'], @@ -1141,7 +1138,7 @@ class Hub(LoggingFactory): reply = error.wrap_exception() else: # send the messages - now_s = now.strftime(util.ISO8601) + now_s = now.strftime(ISO8601) for rec in records: header = rec['header'] # include resubmitted in header to prevent digest collision diff --git 
a/IPython/parallel/controller/sqlitedb.py b/IPython/parallel/controller/sqlitedb.py index c6f90bd..c60488d 100644 --- a/IPython/parallel/controller/sqlitedb.py +++ b/IPython/parallel/controller/sqlitedb.py @@ -17,7 +17,7 @@ from zmq.eventloop import ioloop from IPython.utils.traitlets import Unicode, Instance, List from .dictdb import BaseDB -from IPython.parallel.util import ISO8601 +from IPython.utils.jsonutil import ISO8601, date_default, extract_dates #----------------------------------------------------------------------------- # SQLite operators, adapters, and converters @@ -52,13 +52,13 @@ def _convert_datetime(ds): return datetime.strptime(ds, ISO8601) def _adapt_dict(d): - return json.dumps(d) + return json.dumps(d, default=date_default) def _convert_dict(ds): if ds is None: return ds else: - return json.loads(ds) + return extract_dates(json.loads(ds)) def _adapt_bufs(bufs): # this is *horrible* diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py index be7fd96..d25697c 100755 --- a/IPython/parallel/engine/engine.py +++ b/IPython/parallel/engine/engine.py @@ -24,9 +24,10 @@ from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode from IPython.parallel.controller.heartmonitor import Heart from IPython.parallel.factory import RegistrationFactory -from IPython.parallel.streamsession import Message from IPython.parallel.util import disambiguate_url +from IPython.zmq.session import Message + from .streamkernel import Kernel class EngineFactory(RegistrationFactory): diff --git a/IPython/parallel/engine/kernelstarter.py b/IPython/parallel/engine/kernelstarter.py index 3395e4f..c9b558a 100644 --- a/IPython/parallel/engine/kernelstarter.py +++ b/IPython/parallel/engine/kernelstarter.py @@ -8,7 +8,7 @@ from zmq.eventloop import ioloop -from IPython.parallel.streamsession import StreamSession +from IPython.zmq.session import Session class KernelStarter(object): """Object for resetting/killing the Kernel.""" @@ -213,7 +213,7 
@@ def make_starter(up_addr, down_addr, *args, **kwargs): """entry point function for launching a kernelstarter in a subprocess""" loop = ioloop.IOLoop.instance() ctx = zmq.Context() - session = StreamSession() + session = Session() upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop) upstream.connect(up_addr) downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop) diff --git a/IPython/parallel/engine/streamkernel.py b/IPython/parallel/engine/streamkernel.py index 8ea1e7f..f17ce4b 100755 --- a/IPython/parallel/engine/streamkernel.py +++ b/IPython/parallel/engine/streamkernel.py @@ -28,12 +28,13 @@ import zmq from zmq.eventloop import ioloop, zmqstream # Local imports. +from IPython.utils.jsonutil import ISO8601 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode from IPython.zmq.completer import KernelCompleter from IPython.parallel.error import wrap_exception from IPython.parallel.factory import SessionFactory -from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601 +from IPython.parallel.util import serialize_object, unpack_apply_message def printer(*args): pprint(args, stream=sys.__stdout__) @@ -42,7 +43,7 @@ def printer(*args): class _Passer(zmqstream.ZMQStream): """Empty class that implements `send()` that does nothing. 
- Subclass ZMQStream for StreamSession typechecking + Subclass ZMQStream for Session typechecking """ def __init__(self, *args, **kwargs): diff --git a/IPython/parallel/factory.py b/IPython/parallel/factory.py index 48aa12e..4bb53c1 100644 --- a/IPython/parallel/factory.py +++ b/IPython/parallel/factory.py @@ -21,8 +21,8 @@ from zmq.eventloop.ioloop import IOLoop from IPython.config.configurable import Configurable from IPython.utils.traitlets import Int, Instance, Unicode -import IPython.parallel.streamsession as ss from IPython.parallel.util import select_random_ports +from IPython.zmq.session import Session #----------------------------------------------------------------------------- # Classes @@ -43,7 +43,7 @@ class SessionFactory(LoggingFactory): def _context_default(self): return zmq.Context.instance() - session = Instance('IPython.parallel.streamsession.StreamSession') + session = Instance('IPython.zmq.session.Session') loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False) def _loop_default(self): return IOLoop.instance() @@ -53,7 +53,7 @@ class SessionFactory(LoggingFactory): super(SessionFactory, self).__init__(**kwargs) # construct the session - self.session = ss.StreamSession(**kwargs) + self.session = Session(**kwargs) class RegistrationFactory(SessionFactory): diff --git a/IPython/parallel/streamsession.py b/IPython/parallel/streamsession.py deleted file mode 100644 index fbff1ac..0000000 --- a/IPython/parallel/streamsession.py +++ /dev/null @@ -1,483 +0,0 @@ -#!/usr/bin/env python -"""edited session.py to work with streams, and move msg_type to the header -""" -#----------------------------------------------------------------------------- -# Copyright (C) 2010-2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. 
-#----------------------------------------------------------------------------- - -#----------------------------------------------------------------------------- -# Imports -#----------------------------------------------------------------------------- - -import hmac -import os -import pprint -import uuid -from datetime import datetime - -try: - import cPickle - pickle = cPickle -except: - cPickle = None - import pickle - -import zmq -from zmq.utils import jsonapi -from zmq.eventloop.zmqstream import ZMQStream - -from IPython.config.configurable import Configurable -from IPython.utils.importstring import import_item -from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set - -from .util import ISO8601 - -#----------------------------------------------------------------------------- -# utility functions -#----------------------------------------------------------------------------- - -def squash_unicode(obj): - """coerce unicode back to bytestrings.""" - if isinstance(obj,dict): - for key in obj.keys(): - obj[key] = squash_unicode(obj[key]) - if isinstance(key, unicode): - obj[squash_unicode(key)] = obj.pop(key) - elif isinstance(obj, list): - for i,v in enumerate(obj): - obj[i] = squash_unicode(v) - elif isinstance(obj, unicode): - obj = obj.encode('utf8') - return obj - -def _date_default(obj): - if isinstance(obj, datetime): - return obj.strftime(ISO8601) - else: - raise TypeError("%r is not JSON serializable"%obj) - -#----------------------------------------------------------------------------- -# globals and defaults -#----------------------------------------------------------------------------- - -_default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default' -json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default}) -json_unpacker = lambda s: squash_unicode(jsonapi.loads(s)) - -pickle_packer = lambda o: pickle.dumps(o,-1) -pickle_unpacker = pickle.loads - -default_packer = json_packer 
-default_unpacker = json_unpacker - - -DELIM="" - -#----------------------------------------------------------------------------- -# Classes -#----------------------------------------------------------------------------- - -class Message(object): - """A simple message object that maps dict keys to attributes. - - A Message can be created from a dict and a dict from a Message instance - simply by calling dict(msg_obj).""" - - def __init__(self, msg_dict): - dct = self.__dict__ - for k, v in dict(msg_dict).iteritems(): - if isinstance(v, dict): - v = Message(v) - dct[k] = v - - # Having this iterator lets dict(msg_obj) work out of the box. - def __iter__(self): - return iter(self.__dict__.iteritems()) - - def __repr__(self): - return repr(self.__dict__) - - def __str__(self): - return pprint.pformat(self.__dict__) - - def __contains__(self, k): - return k in self.__dict__ - - def __getitem__(self, k): - return self.__dict__[k] - - -def msg_header(msg_id, msg_type, username, session): - date=datetime.now().strftime(ISO8601) - return locals() - -def extract_header(msg_or_header): - """Given a message or header, return the header.""" - if not msg_or_header: - return {} - try: - # See if msg_or_header is the entire message. - h = msg_or_header['header'] - except KeyError: - try: - # See if msg_or_header is just the header - h = msg_or_header['msg_id'] - except KeyError: - raise - else: - h = msg_or_header - if not isinstance(h, dict): - h = dict(h) - return h - -class StreamSession(Configurable): - """tweaked version of IPython.zmq.session.Session, for development in Parallel""" - debug=Bool(False, config=True, help="""Debug output in the StreamSession""") - packer = Unicode('json',config=True, - help="""The name of the packer for serializing messages. 
- Should be one of 'json', 'pickle', or an import name - for a custom serializer.""") - def _packer_changed(self, name, old, new): - if new.lower() == 'json': - self.pack = json_packer - self.unpack = json_unpacker - elif new.lower() == 'pickle': - self.pack = pickle_packer - self.unpack = pickle_unpacker - else: - self.pack = import_item(new) - - unpacker = Unicode('json',config=True, - help="""The name of the unpacker for unserializing messages. - Only used with custom functions for `packer`.""") - def _unpacker_changed(self, name, old, new): - if new.lower() == 'json': - self.pack = json_packer - self.unpack = json_unpacker - elif new.lower() == 'pickle': - self.pack = pickle_packer - self.unpack = pickle_unpacker - else: - self.unpack = import_item(new) - - session = CStr('',config=True, - help="""The UUID identifying this session.""") - def _session_default(self): - return bytes(uuid.uuid4()) - username = Unicode(os.environ.get('USER','username'), config=True, - help="""Username for the Session. 
Default is your system username.""") - - # message signature related traits: - key = CStr('', config=True, - help="""execution key, for extra authentication.""") - def _key_changed(self, name, old, new): - if new: - self.auth = hmac.HMAC(new) - else: - self.auth = None - auth = Instance(hmac.HMAC) - counters = Instance('collections.defaultdict', (int,)) - digest_history = Set() - - keyfile = Unicode('', config=True, - help="""path to file containing execution key.""") - def _keyfile_changed(self, name, old, new): - with open(new, 'rb') as f: - self.key = f.read().strip() - - pack = Any(default_packer) # the actual packer function - def _pack_changed(self, name, old, new): - if not callable(new): - raise TypeError("packer must be callable, not %s"%type(new)) - - unpack = Any(default_unpacker) # the actual packer function - def _unpack_changed(self, name, old, new): - if not callable(new): - raise TypeError("packer must be callable, not %s"%type(new)) - - def __init__(self, **kwargs): - super(StreamSession, self).__init__(**kwargs) - self.none = self.pack({}) - - @property - def msg_id(self): - """always return new uuid""" - return str(uuid.uuid4()) - - def msg_header(self, msg_type): - return msg_header(self.msg_id, msg_type, self.username, self.session) - - def msg(self, msg_type, content=None, parent=None, subheader=None): - msg = {} - msg['header'] = self.msg_header(msg_type) - msg['msg_id'] = msg['header']['msg_id'] - msg['parent_header'] = {} if parent is None else extract_header(parent) - msg['msg_type'] = msg_type - msg['content'] = {} if content is None else content - sub = {} if subheader is None else subheader - msg['header'].update(sub) - return msg - - def check_key(self, msg_or_header): - """Check that a message's header has the right key""" - if not self.key: - return True - header = extract_header(msg_or_header) - return header.get('key', '') == self.key - - def sign(self, msg): - """Sign a message with HMAC digest. 
If no auth, return b''.""" - if self.auth is None: - return b'' - h = self.auth.copy() - for m in msg: - h.update(m) - return h.hexdigest() - - def serialize(self, msg, ident=None): - content = msg.get('content', {}) - if content is None: - content = self.none - elif isinstance(content, dict): - content = self.pack(content) - elif isinstance(content, bytes): - # content is already packed, as in a relayed message - pass - elif isinstance(content, unicode): - # should be bytes, but JSON often spits out unicode - content = content.encode('utf8') - else: - raise TypeError("Content incorrect type: %s"%type(content)) - - real_message = [self.pack(msg['header']), - self.pack(msg['parent_header']), - content - ] - - to_send = [] - - if isinstance(ident, list): - # accept list of idents - to_send.extend(ident) - elif ident is not None: - to_send.append(ident) - to_send.append(DELIM) - - signature = self.sign(real_message) - to_send.append(signature) - - to_send.extend(real_message) - - return to_send - - def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False): - """Build and send a message via stream or socket. - - Parameters - ---------- - - stream : zmq.Socket or ZMQStream - the socket-like object used to send the data - msg_or_type : str or Message/dict - Normally, msg_or_type will be a msg_type unless a message is being sent more - than once. - - content : dict or None - the content of the message (ignored if msg_or_type is a message) - buffers : list or None - the already-serialized buffers to be appended to the message - parent : Message or dict or None - the parent or parent header describing the parent of this message - subheader : dict or None - extra header keys for this message's header - ident : bytes or list of bytes - the zmq.IDENTITY routing path - track : bool - whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages. 
- - Returns - ------- - msg : message dict - the constructed message - (msg,tracker) : (message dict, MessageTracker) - if track=True, then a 2-tuple will be returned, the first element being the constructed - message, and the second being the MessageTracker - - """ - - if not isinstance(stream, (zmq.Socket, ZMQStream)): - raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) - elif track and isinstance(stream, ZMQStream): - raise TypeError("ZMQStream cannot track messages") - - if isinstance(msg_or_type, (Message, dict)): - # we got a Message, not a msg_type - # don't build a new Message - msg = msg_or_type - else: - msg = self.msg(msg_or_type, content, parent, subheader) - - buffers = [] if buffers is None else buffers - to_send = self.serialize(msg, ident) - flag = 0 - if buffers: - flag = zmq.SNDMORE - _track = False - else: - _track=track - if track: - tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) - else: - tracker = stream.send_multipart(to_send, flag, copy=False) - for b in buffers[:-1]: - stream.send(b, flag, copy=False) - if buffers: - if track: - tracker = stream.send(buffers[-1], copy=False, track=track) - else: - tracker = stream.send(buffers[-1], copy=False) - - # omsg = Message(msg) - if self.debug: - pprint.pprint(msg) - pprint.pprint(to_send) - pprint.pprint(buffers) - - msg['tracker'] = tracker - - return msg - - def send_raw(self, stream, msg, flags=0, copy=True, ident=None): - """Send a raw message via ident path. 
- - Parameters - ---------- - msg : list of sendable buffers""" - to_send = [] - if isinstance(ident, bytes): - ident = [ident] - if ident is not None: - to_send.extend(ident) - - to_send.append(DELIM) - to_send.append(self.sign(msg)) - to_send.extend(msg) - stream.send_multipart(msg, flags, copy=copy) - - def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): - """receives and unpacks a message - returns [idents], msg""" - if isinstance(socket, ZMQStream): - socket = socket.socket - try: - msg = socket.recv_multipart(mode, copy=copy) - except zmq.ZMQError as e: - if e.errno == zmq.EAGAIN: - # We can convert EAGAIN to None as we know in this case - # recv_multipart won't return None. - return None - else: - raise - # return an actual Message object - # determine the number of idents by trying to unpack them. - # this is terrible: - idents, msg = self.feed_identities(msg, copy) - try: - return idents, self.unpack_message(msg, content=content, copy=copy) - except Exception as e: - print (idents, msg) - # TODO: handle it - raise e - - def feed_identities(self, msg, copy=True): - """feed until DELIM is reached, then return the prefix as idents and remainder as - msg. This is easily broken by setting an IDENT to DELIM, but that would be silly. - - Parameters - ---------- - msg : a list of Message or bytes objects - the message to be split - copy : bool - flag determining whether the arguments are bytes or Messages - - Returns - ------- - (idents,msg) : two lists - idents will always be a list of bytes - the indentity prefix - msg will be a list of bytes or Messages, unchanged from input - msg should be unpackable via self.unpack_message at this point. 
- """ - if copy: - idx = msg.index(DELIM) - return msg[:idx], msg[idx+1:] - else: - failed = True - for idx,m in enumerate(msg): - if m.bytes == DELIM: - failed = False - break - if failed: - raise ValueError("DELIM not in msg") - idents, msg = msg[:idx], msg[idx+1:] - return [m.bytes for m in idents], msg - - def unpack_message(self, msg, content=True, copy=True): - """Return a message object from the format - sent by self.send. - - Parameters: - ----------- - - content : bool (True) - whether to unpack the content dict (True), - or leave it serialized (False) - - copy : bool (True) - whether to return the bytes (True), - or the non-copying Message object in each place (False) - - """ - minlen = 4 - message = {} - if not copy: - for i in range(minlen): - msg[i] = msg[i].bytes - if self.auth is not None: - signature = msg[0] - if signature in self.digest_history: - raise ValueError("Duplicate Signature: %r"%signature) - self.digest_history.add(signature) - check = self.sign(msg[1:4]) - if not signature == check: - raise ValueError("Invalid Signature: %r"%signature) - if not len(msg) >= minlen: - raise TypeError("malformed message, must have at least %i elements"%minlen) - message['header'] = self.unpack(msg[1]) - message['msg_type'] = message['header']['msg_type'] - message['parent_header'] = self.unpack(msg[2]) - if content: - message['content'] = self.unpack(msg[3]) - else: - message['content'] = msg[3] - - message['buffers'] = msg[4:] - return message - -def test_msg2obj(): - am = dict(x=1) - ao = Message(am) - assert ao.x == am['x'] - - am['y'] = dict(z=1) - ao = Message(am) - assert ao.y.z == am['y']['z'] - - k1, k2 = 'y', 'z' - assert ao[k1][k2] == am[k1][k2] - - am2 = dict(ao) - assert am['x'] == am2['x'] - assert am['y']['z'] == am2['y']['z'] diff --git a/IPython/parallel/tests/test_db.py b/IPython/parallel/tests/test_db.py index fc47b61..e0b0a3d 100644 --- a/IPython/parallel/tests/test_db.py +++ b/IPython/parallel/tests/test_db.py @@ -20,18 +20,21 @@ from 
unittest import TestCase from nose import SkipTest -from IPython.parallel import error, streamsession as ss +from IPython.parallel import error from IPython.parallel.controller.dictdb import DictDB from IPython.parallel.controller.sqlitedb import SQLiteDB from IPython.parallel.controller.hub import init_record, empty_record +from IPython.zmq.session import Session + + #------------------------------------------------------------------------------- # TestCases #------------------------------------------------------------------------------- class TestDictBackend(TestCase): def setUp(self): - self.session = ss.StreamSession() + self.session = Session() self.db = self.create_db() self.load_records(16) diff --git a/IPython/parallel/tests/test_streamsession.py b/IPython/parallel/tests/test_streamsession.py deleted file mode 100644 index 051c11c..0000000 --- a/IPython/parallel/tests/test_streamsession.py +++ /dev/null @@ -1,111 +0,0 @@ -"""test building messages with streamsession""" - -#------------------------------------------------------------------------------- -# Copyright (C) 2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. 
-#------------------------------------------------------------------------------- - -#------------------------------------------------------------------------------- -# Imports -#------------------------------------------------------------------------------- - -import os -import uuid -import zmq - -from zmq.tests import BaseZMQTestCase -from zmq.eventloop.zmqstream import ZMQStream -# from IPython.zmq.tests import SessionTestCase -from IPython.parallel import streamsession as ss - -class SessionTestCase(BaseZMQTestCase): - - def setUp(self): - BaseZMQTestCase.setUp(self) - self.session = ss.StreamSession() - -class TestSession(SessionTestCase): - - def test_msg(self): - """message format""" - msg = self.session.msg('execute') - thekeys = set('header msg_id parent_header msg_type content'.split()) - s = set(msg.keys()) - self.assertEquals(s, thekeys) - self.assertTrue(isinstance(msg['content'],dict)) - self.assertTrue(isinstance(msg['header'],dict)) - self.assertTrue(isinstance(msg['parent_header'],dict)) - self.assertEquals(msg['msg_type'], 'execute') - - - - def test_args(self): - """initialization arguments for StreamSession""" - s = self.session - self.assertTrue(s.pack is ss.default_packer) - self.assertTrue(s.unpack is ss.default_unpacker) - self.assertEquals(s.username, os.environ.get('USER', 'username')) - - s = ss.StreamSession() - self.assertEquals(s.username, os.environ.get('USER', 'username')) - - self.assertRaises(TypeError, ss.StreamSession, pack='hi') - self.assertRaises(TypeError, ss.StreamSession, unpack='hi') - u = str(uuid.uuid4()) - s = ss.StreamSession(username='carrot', session=u) - self.assertEquals(s.session, u) - self.assertEquals(s.username, 'carrot') - - def test_tracking(self): - """test tracking messages""" - a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) - s = self.session - stream = ZMQStream(a) - msg = s.send(a, 'hello', track=False) - self.assertTrue(msg['tracker'] is None) - msg = s.send(a, 'hello', track=True) - 
self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) - M = zmq.Message(b'hi there', track=True) - msg = s.send(a, 'hello', buffers=[M], track=True) - t = msg['tracker'] - self.assertTrue(isinstance(t, zmq.MessageTracker)) - self.assertRaises(zmq.NotDone, t.wait, .1) - del M - t.wait(1) # this will raise - - - # def test_rekey(self): - # """rekeying dict around json str keys""" - # d = {'0': uuid.uuid4(), 0:uuid.uuid4()} - # self.assertRaises(KeyError, ss.rekey, d) - # - # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()} - # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']} - # rd = ss.rekey(d) - # self.assertEquals(d2,rd) - # - # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()} - # d2 = {1.5:d['1.5'],1:d['1']} - # rd = ss.rekey(d) - # self.assertEquals(d2,rd) - # - # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()} - # self.assertRaises(KeyError, ss.rekey, d) - # - def test_unique_msg_ids(self): - """test that messages receive unique ids""" - ids = set() - for i in range(2**12): - h = self.session.msg_header('test') - msg_id = h['msg_id'] - self.assertTrue(msg_id not in ids) - ids.add(msg_id) - - def test_feed_identities(self): - """scrub the front for zmq IDENTITIES""" - theids = "engine client other".split() - content = dict(code='whoda',stuff=object()) - themsg = self.session.msg('execute',content=content) - pmsg = theids diff --git a/IPython/parallel/util.py b/IPython/parallel/util.py index 332322e..f777c43 100644 --- a/IPython/parallel/util.py +++ b/IPython/parallel/util.py @@ -17,7 +17,6 @@ import re import stat import socket import sys -from datetime import datetime from signal import signal, SIGINT, SIGABRT, SIGTERM try: from signal import SIGKILL @@ -40,10 +39,6 @@ from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence from IPython.utils.newserialized import serialize, unserialize from IPython.zmq.log import EnginePUBHandler -# globals -ISO8601="%Y-%m-%dT%H:%M:%S.%f" -ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$") 
- #----------------------------------------------------------------------------- # Classes #----------------------------------------------------------------------------- @@ -101,18 +96,6 @@ class ReverseDict(dict): # Functions #----------------------------------------------------------------------------- -def extract_dates(obj): - """extract ISO8601 dates from unpacked JSON""" - if isinstance(obj, dict): - for k,v in obj.iteritems(): - obj[k] = extract_dates(v) - elif isinstance(obj, list): - obj = [ extract_dates(o) for o in obj ] - elif isinstance(obj, basestring): - if ISO8601_RE.match(obj): - obj = datetime.strptime(obj, ISO8601) - return obj - def validate_url(url): """validate a url for zeromq""" if not isinstance(url, basestring): diff --git a/IPython/utils/jsonutil.py b/IPython/utils/jsonutil.py index f7b1f76..d55fba0 100644 --- a/IPython/utils/jsonutil.py +++ b/IPython/utils/jsonutil.py @@ -11,12 +11,43 @@ # Imports #----------------------------------------------------------------------------- # stdlib +import re import types +from datetime import datetime + +#----------------------------------------------------------------------------- +# Globals and constants +#----------------------------------------------------------------------------- + +# timestamp formats +ISO8601="%Y-%m-%dT%H:%M:%S.%f" +ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$") #----------------------------------------------------------------------------- # Classes and functions #----------------------------------------------------------------------------- +def extract_dates(obj): + """extract ISO8601 dates from unpacked JSON""" + if isinstance(obj, dict): + for k,v in obj.iteritems(): + obj[k] = extract_dates(v) + elif isinstance(obj, list): + obj = [ extract_dates(o) for o in obj ] + elif isinstance(obj, basestring): + if ISO8601_PAT.match(obj): + obj = datetime.strptime(obj, ISO8601) + return obj + +def date_default(obj): + """default function for packing datetime 
objects""" + if isinstance(obj, datetime): + return obj.strftime(ISO8601) + else: + raise TypeError("%r is not JSON serializable"%obj) + + + def json_clean(obj): """Clean an object to ensure it's safe to encode in JSON. diff --git a/IPython/zmq/session.py b/IPython/zmq/session.py index afb9907..e3bf3bc 100644 --- a/IPython/zmq/session.py +++ b/IPython/zmq/session.py @@ -1,10 +1,77 @@ +#!/usr/bin/env python +"""edited session.py to work with streams, and move msg_type to the header +""" +#----------------------------------------------------------------------------- +# Copyright (C) 2010-2011 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Imports +#----------------------------------------------------------------------------- + +import hmac import os -import uuid import pprint +import uuid +from datetime import datetime + +try: + import cPickle + pickle = cPickle +except: + cPickle = None + import pickle import zmq +from zmq.utils import jsonapi +from zmq.eventloop.zmqstream import ZMQStream + +from IPython.config.configurable import Configurable +from IPython.utils.importstring import import_item +from IPython.utils.jsonutil import date_default +from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set + +#----------------------------------------------------------------------------- +# utility functions +#----------------------------------------------------------------------------- + +def squash_unicode(obj): + """coerce unicode back to bytestrings.""" + if isinstance(obj,dict): + for key in obj.keys(): + obj[key] = squash_unicode(obj[key]) + if isinstance(key, unicode): + obj[squash_unicode(key)] = obj.pop(key) + elif isinstance(obj, list): + for i,v in enumerate(obj): + 
obj[i] = squash_unicode(v) + elif isinstance(obj, unicode): + obj = obj.encode('utf8') + return obj + +#----------------------------------------------------------------------------- +# globals and defaults +#----------------------------------------------------------------------------- + +_default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default' +json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:date_default}) +json_unpacker = lambda s: squash_unicode(jsonapi.loads(s)) -from zmq.utils import jsonapi as json +pickle_packer = lambda o: pickle.dumps(o,-1) +pickle_unpacker = pickle.loads + +default_packer = json_packer +default_unpacker = json_unpacker + + +DELIM="" + +#----------------------------------------------------------------------------- +# Classes +#----------------------------------------------------------------------------- class Message(object): """A simple message object that maps dict keys to attributes. @@ -14,7 +81,7 @@ class Message(object): def __init__(self, msg_dict): dct = self.__dict__ - for k, v in msg_dict.iteritems(): + for k, v in dict(msg_dict).iteritems(): if isinstance(v, dict): v = Message(v) dct[k] = v @@ -36,13 +103,9 @@ class Message(object): return self.__dict__[k] -def msg_header(msg_id, username, session): - return { - 'msg_id' : msg_id, - 'username' : username, - 'session' : session - } - +def msg_header(msg_id, msg_type, username, session): + date=datetime.now() + return locals() def extract_header(msg_or_header): """Given a message or header, return the header.""" @@ -63,109 +126,341 @@ def extract_header(msg_or_header): h = dict(h) return h - -class Session(object): - - def __init__(self, username=os.environ.get('USER','username'), session=None): - self.username = username - if session is None: - self.session = str(uuid.uuid4()) +class Session(Configurable): + """tweaked version of IPython.zmq.session.Session, for development in Parallel""" + debug=Bool(False, config=True, help="""Debug output 
in the Session""") + packer = Unicode('json',config=True, + help="""The name of the packer for serializing messages. + Should be one of 'json', 'pickle', or an import name + for a custom serializer.""") + def _packer_changed(self, name, old, new): + if new.lower() == 'json': + self.pack = json_packer + self.unpack = json_unpacker + elif new.lower() == 'pickle': + self.pack = pickle_packer + self.unpack = pickle_unpacker else: - self.session = session - self.msg_id = 0 + self.pack = import_item(new) - def msg_header(self): - h = msg_header(self.msg_id, self.username, self.session) - self.msg_id += 1 - return h + unpacker = Unicode('json',config=True, + help="""The name of the unpacker for unserializing messages. + Only used with custom functions for `packer`.""") + def _unpacker_changed(self, name, old, new): + if new.lower() == 'json': + self.pack = json_packer + self.unpack = json_unpacker + elif new.lower() == 'pickle': + self.pack = pickle_packer + self.unpack = pickle_unpacker + else: + self.unpack = import_item(new) + + session = CStr('',config=True, + help="""The UUID identifying this session.""") + def _session_default(self): + return bytes(uuid.uuid4()) + username = Unicode(os.environ.get('USER','username'), config=True, + help="""Username for the Session. Default is your system username.""") + + # message signature related traits: + key = CStr('', config=True, + help="""execution key, for extra authentication.""") + def _key_changed(self, name, old, new): + if new: + self.auth = hmac.HMAC(new) + else: + self.auth = None + auth = Instance(hmac.HMAC) + counters = Instance('collections.defaultdict', (int,)) + digest_history = Set() + + keyfile = Unicode('', config=True, + help="""path to file containing execution key.""") + def _keyfile_changed(self, name, old, new): + with open(new, 'rb') as f: + self.key = f.read().strip() - def msg(self, msg_type, content=None, parent=None): - """Construct a standard-form message, with a given type, content, and parent. 
+ pack = Any(default_packer) # the actual packer function + def _pack_changed(self, name, old, new): + if not callable(new): + raise TypeError("packer must be callable, not %s"%type(new)) - NOT to be called directly. - """ + unpack = Any(default_unpacker) # the actual packer function + def _unpack_changed(self, name, old, new): + if not callable(new): + raise TypeError("packer must be callable, not %s"%type(new)) + + def __init__(self, **kwargs): + super(Session, self).__init__(**kwargs) + self.none = self.pack({}) + + @property + def msg_id(self): + """always return new uuid""" + return str(uuid.uuid4()) + + def msg_header(self, msg_type): + return msg_header(self.msg_id, msg_type, self.username, self.session) + + def msg(self, msg_type, content=None, parent=None, subheader=None): msg = {} - msg['header'] = self.msg_header() + msg['header'] = self.msg_header(msg_type) + msg['msg_id'] = msg['header']['msg_id'] msg['parent_header'] = {} if parent is None else extract_header(parent) msg['msg_type'] = msg_type msg['content'] = {} if content is None else content + sub = {} if subheader is None else subheader + msg['header'].update(sub) return msg - def send(self, socket, msg_or_type, content=None, parent=None, ident=None): - """send a message via a socket, using a uniform message pattern. + def check_key(self, msg_or_header): + """Check that a message's header has the right key""" + if not self.key: + return True + header = extract_header(msg_or_header) + return header.get('key', '') == self.key + + def sign(self, msg): + """Sign a message with HMAC digest. 
If no auth, return b''.""" + if self.auth is None: + return b'' + h = self.auth.copy() + for m in msg: + h.update(m) + return h.hexdigest() + + def serialize(self, msg, ident=None): + content = msg.get('content', {}) + if content is None: + content = self.none + elif isinstance(content, dict): + content = self.pack(content) + elif isinstance(content, bytes): + # content is already packed, as in a relayed message + pass + elif isinstance(content, unicode): + # should be bytes, but JSON often spits out unicode + content = content.encode('utf8') + else: + raise TypeError("Content incorrect type: %s"%type(content)) + + real_message = [self.pack(msg['header']), + self.pack(msg['parent_header']), + content + ] + + to_send = [] + + if isinstance(ident, list): + # accept list of idents + to_send.extend(ident) + elif ident is not None: + to_send.append(ident) + to_send.append(DELIM) + + signature = self.sign(real_message) + to_send.append(signature) + + to_send.extend(real_message) + + return to_send + + def send(self, stream, msg_or_type, content=None, parent=None, ident=None, + buffers=None, subheader=None, track=False): + """Build and send a message via stream or socket. Parameters ---------- - socket : zmq.Socket - The socket on which to send. - msg_or_type : Message/dict or str - if str : then a new message will be constructed from content,parent - if Message/dict : then content and parent are ignored, and the message - is sent. This is only for use when sending a Message for a second time. - content : dict, optional - The contents of the message - parent : dict, optional - The parent header, or parent message, of this message - ident : bytes, optional - The zmq.IDENTITY prefix of the destination. - Only for use on certain socket types. + + stream : zmq.Socket or ZMQStream + the socket-like object used to send the data + msg_or_type : str or Message/dict + Normally, msg_or_type will be a msg_type unless a message is being sent more + than once. 
+ + content : dict or None + the content of the message (ignored if msg_or_type is a message) + parent : Message or dict or None + the parent or parent header describing the parent of this message + ident : bytes or list of bytes + the zmq.IDENTITY routing path + subheader : dict or None + extra header keys for this message's header + buffers : list or None + the already-serialized buffers to be appended to the message + track : bool + whether to track. Only for use with Sockets, + because ZMQStream objects cannot track messages. Returns ------- - msg : dict - The message, as constructed by self.msg(msg_type,content,parent) + msg : message dict + the constructed message + (msg,tracker) : (message dict, MessageTracker) + if track=True, then a 2-tuple will be returned, + the first element being the constructed + message, and the second being the MessageTracker + """ + + if not isinstance(stream, (zmq.Socket, ZMQStream)): + raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream)) + elif track and isinstance(stream, ZMQStream): + raise TypeError("ZMQStream cannot track messages") + if isinstance(msg_or_type, (Message, dict)): - msg = dict(msg_or_type) + # we got a Message, not a msg_type + # don't build a new Message + msg = msg_or_type else: - msg = self.msg(msg_or_type, content, parent) - if ident is not None: - socket.send(ident, zmq.SNDMORE) - socket.send_json(msg) - return msg - - def recv(self, socket, mode=zmq.NOBLOCK): - """recv a message on a socket. + msg = self.msg(msg_or_type, content, parent, subheader) - Receive an optionally identity-prefixed message, as sent via session.send(). 
+ buffers = [] if buffers is None else buffers + to_send = self.serialize(msg, ident) + flag = 0 + if buffers: + flag = zmq.SNDMORE + _track = False + else: + _track=track + if track: + tracker = stream.send_multipart(to_send, flag, copy=False, track=_track) + else: + tracker = stream.send_multipart(to_send, flag, copy=False) + for b in buffers[:-1]: + stream.send(b, flag, copy=False) + if buffers: + if track: + tracker = stream.send(buffers[-1], copy=False, track=track) + else: + tracker = stream.send(buffers[-1], copy=False) + + # omsg = Message(msg) + if self.debug: + pprint.pprint(msg) + pprint.pprint(to_send) + pprint.pprint(buffers) - Parameters - ---------- + msg['tracker'] = tracker - socket : zmq.Socket - The socket on which to recv a message. - mode : int, optional - the mode flag passed to socket.recv - default: zmq.NOBLOCK + return msg + + def send_raw(self, stream, msg, flags=0, copy=True, ident=None): + """Send a raw message via ident path. - Returns - ------- - (ident,msg) : tuple - always length 2. If no message received, then return is (None,None) - ident : bytes or None - the identity prefix is there was one, None otherwise. - msg : dict or None - The actual message. If mode==zmq.NOBLOCK and no message was waiting, - it will be None. - """ + Parameters + ---------- + msg : list of sendable buffers""" + to_send = [] + if isinstance(ident, bytes): + ident = [ident] + if ident is not None: + to_send.extend(ident) + + to_send.append(DELIM) + to_send.append(self.sign(msg)) + to_send.extend(msg) + stream.send_multipart(msg, flags, copy=copy) + + def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): + """receives and unpacks a message + returns [idents], msg""" + if isinstance(socket, ZMQStream): + socket = socket.socket try: msg = socket.recv_multipart(mode) - except zmq.ZMQError, e: + except zmq.ZMQError as e: if e.errno == zmq.EAGAIN: # We can convert EAGAIN to None as we know in this case - # recv_json won't return None. 
+ # recv_multipart won't return None. return None,None else: raise - if len(msg) == 1: - ident=None - msg = msg[0] - elif len(msg) == 2: - ident, msg = msg + # return an actual Message object + # determine the number of idents by trying to unpack them. + # this is terrible: + idents, msg = self.feed_identities(msg, copy) + try: + return idents, self.unpack_message(msg, content=content, copy=copy) + except Exception as e: + print (idents, msg) + # TODO: handle it + raise e + + def feed_identities(self, msg, copy=True): + """feed until DELIM is reached, then return the prefix as idents and remainder as + msg. This is easily broken by setting an IDENT to DELIM, but that would be silly. + + Parameters + ---------- + msg : a list of Message or bytes objects + the message to be split + copy : bool + flag determining whether the arguments are bytes or Messages + + Returns + ------- + (idents,msg) : two lists + idents will always be a list of bytes - the indentity prefix + msg will be a list of bytes or Messages, unchanged from input + msg should be unpackable via self.unpack_message at this point. + """ + if copy: + idx = msg.index(DELIM) + return msg[:idx], msg[idx+1:] + else: + failed = True + for idx,m in enumerate(msg): + if m.bytes == DELIM: + failed = False + break + if failed: + raise ValueError("DELIM not in msg") + idents, msg = msg[:idx], msg[idx+1:] + return [m.bytes for m in idents], msg + + def unpack_message(self, msg, content=True, copy=True): + """Return a message object from the format + sent by self.send. 
+ + Parameters: + ----------- + + content : bool (True) + whether to unpack the content dict (True), + or leave it serialized (False) + + copy : bool (True) + whether to return the bytes (True), + or the non-copying Message object in each place (False) + + """ + minlen = 4 + message = {} + if not copy: + for i in range(minlen): + msg[i] = msg[i].bytes + if self.auth is not None: + signature = msg[0] + if signature in self.digest_history: + raise ValueError("Duplicate Signature: %r"%signature) + self.digest_history.add(signature) + check = self.sign(msg[1:4]) + if not signature == check: + raise ValueError("Invalid Signature: %r"%signature) + if not len(msg) >= minlen: + raise TypeError("malformed message, must have at least %i elements"%minlen) + message['header'] = self.unpack(msg[1]) + message['msg_type'] = message['header']['msg_type'] + message['parent_header'] = self.unpack(msg[2]) + if content: + message['content'] = self.unpack(msg[3]) else: - raise ValueError("Got message with length > 2, which is invalid") + message['content'] = msg[3] - return ident, json.loads(msg) + message['buffers'] = msg[4:] + return message def test_msg2obj(): am = dict(x=1) diff --git a/IPython/zmq/tests/test_session.py b/IPython/zmq/tests/test_session.py index 051c11c..6279acc 100644 --- a/IPython/zmq/tests/test_session.py +++ b/IPython/zmq/tests/test_session.py @@ -17,14 +17,14 @@ import zmq from zmq.tests import BaseZMQTestCase from zmq.eventloop.zmqstream import ZMQStream -# from IPython.zmq.tests import SessionTestCase -from IPython.parallel import streamsession as ss + +from IPython.zmq import session as ss class SessionTestCase(BaseZMQTestCase): def setUp(self): BaseZMQTestCase.setUp(self) - self.session = ss.StreamSession() + self.session = ss.Session() class TestSession(SessionTestCase): @@ -42,19 +42,19 @@ class TestSession(SessionTestCase): def test_args(self): - """initialization arguments for StreamSession""" + """initialization arguments for Session""" s = 
self.session self.assertTrue(s.pack is ss.default_packer) self.assertTrue(s.unpack is ss.default_unpacker) self.assertEquals(s.username, os.environ.get('USER', 'username')) - s = ss.StreamSession() + s = ss.Session() self.assertEquals(s.username, os.environ.get('USER', 'username')) - self.assertRaises(TypeError, ss.StreamSession, pack='hi') - self.assertRaises(TypeError, ss.StreamSession, unpack='hi') + self.assertRaises(TypeError, ss.Session, pack='hi') + self.assertRaises(TypeError, ss.Session, unpack='hi') u = str(uuid.uuid4()) - s = ss.StreamSession(username='carrot', session=u) + s = ss.Session(username='carrot', session=u) self.assertEquals(s.session, u) self.assertEquals(s.username, 'carrot')