diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index e54483f..83951f2 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -265,7 +265,7 @@ class Client(HasTraits): _context = Instance('zmq.Context') _config = Dict() _engines=Instance(ReverseDict, (), {}) - _registration_socket=Instance('zmq.Socket') + # _hub_socket=Instance('zmq.Socket') _query_socket=Instance('zmq.Socket') _control_socket=Instance('zmq.Socket') _iopub_socket=Instance('zmq.Socket') @@ -339,12 +339,12 @@ class Client(HasTraits): self.session = ss.StreamSession(**key_arg) else: self.session = ss.StreamSession(username, **key_arg) - self._registration_socket = self._context.socket(zmq.XREQ) - self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) + self._query_socket = self._context.socket(zmq.XREQ) + self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) if self._ssh: - tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs) + tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs) else: - self._registration_socket.connect(url) + self._query_socket.connect(url) self.session.debug = self.debug @@ -449,8 +449,8 @@ class Client(HasTraits): else: return s.connect(url) - self.session.send(self._registration_socket, 'connection_request') - idents,msg = self.session.recv(self._registration_socket,mode=0) + self.session.send(self._query_socket, 'connection_request') + idents,msg = self.session.recv(self._query_socket,mode=0) if self.debug: pprint(msg) msg = ss.Message(msg) @@ -458,29 +458,29 @@ class Client(HasTraits): self._config['registration'] = dict(content) if content.status == 'ok': if content.mux: - self._mux_socket = self._context.socket(zmq.PAIR) + self._mux_socket = self._context.socket(zmq.XREQ) self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._mux_socket, content.mux) if content.task: self._task_scheme, task_addr = content.task - self._task_socket = self._context.socket(zmq.PAIR) + self._task_socket = self._context.socket(zmq.XREQ) self._task_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._task_socket, task_addr) if content.notification: self._notification_socket = self._context.socket(zmq.SUB) connect_socket(self._notification_socket, content.notification) - self._notification_socket.setsockopt(zmq.SUBSCRIBE, "") - if content.query: - self._query_socket = self._context.socket(zmq.PAIR) - self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) - connect_socket(self._query_socket, content.query) + self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'') + # if content.query: + # self._query_socket = self._context.socket(zmq.XREQ) + # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) + # connect_socket(self._query_socket, content.query) if content.control: - self._control_socket = self._context.socket(zmq.PAIR) + self._control_socket = self._context.socket(zmq.XREQ) self._control_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._control_socket, content.control) if content.iopub: self._iopub_socket = self._context.socket(zmq.SUB) - self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '') + self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'') self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session) connect_socket(self._iopub_socket, content.iopub) self._update_engines(dict(content.engines)) @@ -496,6 +496,7 @@ class Client(HasTraits): def _unwrap_exception(self, content): """unwrap exception, and remap engineid to int.""" e = error.unwrap_exception(content) + print e.traceback if e.engine_info: e_uuid = e.engine_info['engine_uuid'] eid = self._engines[e_uuid] diff --git a/IPython/zmq/parallel/engine.py b/IPython/zmq/parallel/engine.py index 08e9b98..c90a0f5 100755 --- a/IPython/zmq/parallel/engine.py +++ b/IPython/zmq/parallel/engine.py @@ -41,7 +41,7 @@ class EngineFactory(RegistrationFactory): super(EngineFactory, self).__init__(**kwargs) ctx = self.context - reg = ctx.socket(zmq.PAIR) + reg = ctx.socket(zmq.XREQ) reg.setsockopt(zmq.IDENTITY, self.ident) reg.connect(self.url) self.registrar = zmqstream.ZMQStream(reg, self.loop) @@ -74,16 +74,26 @@ class EngineFactory(RegistrationFactory): task_addr = msg.content.task if task_addr: shell_addrs.append(str(task_addr)) - shell_streams = [] + + # Uncomment this to go back to two-socket model + # shell_streams = [] + # for addr in shell_addrs: + # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop) + # stream.setsockopt(zmq.IDENTITY, identity) + # stream.connect(disambiguate_url(addr, self.location)) + # shell_streams.append(stream) + + # Now use only one shell stream for mux and tasks + stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop) + stream.setsockopt(zmq.IDENTITY, identity) + shell_streams = [stream] for addr in shell_addrs: - stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop) - stream.setsockopt(zmq.IDENTITY, identity) stream.connect(disambiguate_url(addr, self.location)) - shell_streams.append(stream) + # end single stream-socket # control stream: control_addr = str(msg.content.control) - control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop) + control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop) control_stream.setsockopt(zmq.IDENTITY, identity) control_stream.connect(disambiguate_url(control_addr, self.location)) diff --git a/IPython/zmq/parallel/hub.py b/IPython/zmq/parallel/hub.py index 8166f67..9c45d66 100755 --- a/IPython/zmq/parallel/hub.py +++ b/IPython/zmq/parallel/hub.py @@ -119,10 +119,6 @@ class HubFactory(RegistrationFactory): def _mon_port_default(self): return select_random_ports(1)[0] - query_port = Instance(int, config=True) - def _query_port_default(self): - return select_random_ports(1)[0] - notifier_port = Instance(int, config=True) def _notifier_port_default(self): return select_random_ports(1)[0] @@ -194,11 +190,11 @@ class HubFactory(RegistrationFactory): loop = self.loop # Registrar socket - reg = ZMQStream(ctx.socket(zmq.XREP), loop) - reg.bind(client_iface % self.regport) + q = ZMQStream(ctx.socket(zmq.XREP), loop) + q.bind(client_iface % self.regport) self.log.info("Hub listening on %s for registration."%(client_iface%self.regport)) if self.client_ip != self.engine_ip: - reg.bind(engine_iface % self.regport) + q.bind(engine_iface % self.regport) self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport)) ### Engine connections ### @@ -212,9 +208,6 @@ class HubFactory(RegistrationFactory): period=self.ping, logname=self.log.name) ### Client connections ### - # Clientele socket - c = ZMQStream(ctx.socket(zmq.XREP), loop) - c.bind(client_iface%self.query_port) # Notifier socket n = ZMQStream(ctx.socket(zmq.PUB), loop) n.bind(client_iface%self.notifier_port) @@ -230,7 +223,7 @@ class HubFactory(RegistrationFactory): # connect the db self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1])) - cdir = self.config.Global.cluster_dir + # cdir = self.config.Global.cluster_dir self.db = import_item(self.db_class)(session=self.session.session, config=self.config) time.sleep(.25) @@ -246,7 +239,6 @@ class HubFactory(RegistrationFactory): self.client_info = { 'control' : client_iface%self.control[0], - 'query': client_iface%self.query_port, 'mux': client_iface%self.mux[0], 'task' : (self.scheme, client_iface%self.task[0]), 'iopub' : client_iface%self.iopub[0], @@ -255,7 +247,7 @@ class HubFactory(RegistrationFactory): self.log.debug("Hub engine addrs: %s"%self.engine_info) self.log.debug("Hub client addrs: %s"%self.client_info) self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor, - registrar=reg, clientele=c, notifier=n, db=self.db, + query=q, notifier=n, db=self.db, engine_info=self.engine_info, client_info=self.client_info, logname=self.log.name) @@ -269,10 +261,8 @@ class Hub(LoggingFactory): session: StreamSession object context: zmq context for creating new connections (?) queue: ZMQStream for monitoring the command queue (SUB) - registrar: ZMQStream for engine registration requests (XREP) + query: ZMQStream for engine registration and client queries requests (XREP) heartbeat: HeartMonitor object checking the pulse of the engines - clientele: ZMQStream for client connections (XREP) - not used for jobs, only query/control commands notifier: ZMQStream for broadcasting engine registration changes (PUB) db: connection to db for out of memory logging of commands NotImplemented @@ -300,8 +290,7 @@ class Hub(LoggingFactory): # objects from constructor: loop=Instance(ioloop.IOLoop) - registrar=Instance(ZMQStream) - clientele=Instance(ZMQStream) + query=Instance(ZMQStream) monitor=Instance(ZMQStream) heartmonitor=Instance(HeartMonitor) notifier=Instance(ZMQStream) @@ -317,10 +306,8 @@ class Hub(LoggingFactory): session: streamsession for sending serialized data # engine: queue: ZMQStream for monitoring queue messages - registrar: ZMQStream for engine registration + query: ZMQStream for engine+client registration and client requests heartbeat: HeartMonitor object for tracking engines - # client: - clientele: ZMQStream for client connections # extra: db: ZMQStream for db connection (NotImplemented) engine_info: zmq address/protocol dict for engine connections @@ -340,8 +327,7 @@ class Hub(LoggingFactory): validate_url_container(self.engine_info) # register our callbacks - self.registrar.on_recv(self.dispatch_register_request) - self.clientele.on_recv(self.dispatch_client_msg) + self.query.on_recv(self.dispatch_query) self.monitor.on_recv(self.dispatch_monitor_traffic) self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure) @@ -357,15 +343,13 @@ class Hub(LoggingFactory): 'iopub': self.save_iopub_message, } - self.client_handlers = {'queue_request': self.queue_status, + self.query_handlers = {'queue_request': self.queue_status, 'result_request': self.get_results, 'purge_request': self.purge_results, 'load_request': self.check_load, 'resubmit_request': self.resubmit_task, 'shutdown_request': self.shutdown_request, - } - - self.registrar_handlers = {'registration_request' : self.register_engine, + 'registration_request' : self.register_engine, 'unregistration_request' : self.unregister_engine, 'connection_request': self.connection_request, } @@ -418,27 +402,27 @@ class Hub(LoggingFactory): # dispatch methods (1 per stream) #----------------------------------------------------------------------------- - def dispatch_register_request(self, msg): - """""" - self.log.debug("registration::dispatch_register_request(%s)"%msg) - idents,msg = self.session.feed_identities(msg) - if not idents: - self.log.error("Bad Queue Message: %s"%msg, exc_info=True) - return - try: - msg = self.session.unpack_message(msg,content=True) - except: - self.log.error("registration::got bad registration message: %s"%msg, exc_info=True) - return - - msg_type = msg['msg_type'] - content = msg['content'] - - handler = self.registrar_handlers.get(msg_type, None) - if handler is None: - self.log.error("registration::got bad registration message: %s"%msg) - else: - handler(idents, msg) + # def dispatch_registration_request(self, msg): + # """""" + # self.log.debug("registration::dispatch_register_request(%s)"%msg) + # idents,msg = self.session.feed_identities(msg) + # if not idents: + # self.log.error("Bad Query Message: %s"%msg, exc_info=True) + # return + # try: + # msg = self.session.unpack_message(msg,content=True) + # except: + # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True) + # return + # + # msg_type = msg['msg_type'] + # content = msg['content'] + # + # handler = self.query_handlers.get(msg_type, None) + # if handler is None: + # self.log.error("registration::got bad registration message: %s"%msg) + # else: + # handler(idents, msg) def dispatch_monitor_traffic(self, msg): """all ME and Task queue messages come through here, as well as @@ -456,37 +440,37 @@ class Hub(LoggingFactory): self.log.error("Invalid monitor topic: %s"%switch) - def dispatch_client_msg(self, msg): - """Route messages from clients""" + def dispatch_query(self, msg): + """Route registration requests and queries from clients.""" idents, msg = self.session.feed_identities(msg) if not idents: - self.log.error("Bad Client Message: %s"%msg) + self.log.error("Bad Query Message: %s"%msg) return client_id = idents[0] try: msg = self.session.unpack_message(msg, content=True) except: content = error.wrap_exception() - self.log.error("Bad Client Message: %s"%msg, exc_info=True) - self.session.send(self.clientele, "hub_error", ident=client_id, + self.log.error("Bad Query Message: %s"%msg, exc_info=True) + self.session.send(self.query, "hub_error", ident=client_id, content=content) return # print client_id, header, parent, content #switch on message type: msg_type = msg['msg_type'] - self.log.info("client:: client %s requested %s"%(client_id, msg_type)) - handler = self.client_handlers.get(msg_type, None) + self.log.info("client::client %s requested %s"%(client_id, msg_type)) + handler = self.query_handlers.get(msg_type, None) try: assert handler is not None, "Bad Message Type: %s"%msg_type except: content = error.wrap_exception() self.log.error("Bad Message Type: %s"%msg_type, exc_info=True) - self.session.send(self.clientele, "hub_error", ident=client_id, + self.session.send(self.query, "hub_error", ident=client_id, content=content) return else: - handler(client_id, msg) + handler(idents, msg) def dispatch_db(self, msg): """""" @@ -752,7 +736,7 @@ class Hub(LoggingFactory): for k,v in self.keytable.iteritems(): jsonable[str(k)] = v content['engines'] = jsonable - self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id) + self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id) def register_engine(self, reg, msg): """Register a new engine.""" @@ -801,7 +785,7 @@ class Hub(LoggingFactory): content = error.wrap_exception() break - msg = self.session.send(self.registrar, "registration_reply", + msg = self.session.send(self.query, "registration_reply", content=content, ident=reg) @@ -912,7 +896,7 @@ class Hub(LoggingFactory): # for eid,ec in self.engines.iteritems(): # self.session.send(s, 'shutdown_request', content=dict(restart=False), ident=ec.queue) # time.sleep(1) - self.session.send(self.clientele, 'shutdown_reply', content={'status': 'ok'}, ident=client_id) + self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id) dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop) dc.start() @@ -929,7 +913,7 @@ class Hub(LoggingFactory): targets = self._validate_targets(targets) except: content = error.wrap_exception() - self.session.send(self.clientele, "hub_error", + self.session.send(self.query, "hub_error", content=content, ident=client_id) return @@ -937,7 +921,7 @@ class Hub(LoggingFactory): # loads = {} for t in targets: content[bytes(t)] = len(self.queues[t])+len(self.tasks[t]) - self.session.send(self.clientele, "load_reply", content=content, ident=client_id) + self.session.send(self.query, "load_reply", content=content, ident=client_id) def queue_status(self, client_id, msg): @@ -953,7 +937,7 @@ class Hub(LoggingFactory): targets = self._validate_targets(targets) except: content = error.wrap_exception() - self.session.send(self.clientele, "hub_error", + self.session.send(self.query, "hub_error", content=content, ident=client_id) return verbose = content.get('verbose', False) @@ -968,7 +952,7 @@ class Hub(LoggingFactory): tasks = len(tasks) content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks} # pending - self.session.send(self.clientele, "queue_reply", content=content, ident=client_id) + self.session.send(self.query, "queue_reply", content=content, ident=client_id) def purge_results(self, client_id, msg): """Purge results from memory. This method is more valuable before we move @@ -1006,7 +990,7 @@ class Hub(LoggingFactory): uid = self.engines[eid].queue self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None})) - self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id) + self.session.send(self.query, 'purge_reply', content=reply, ident=client_id) def resubmit_task(self, client_id, msg, buffers): """Resubmit a task.""" @@ -1049,7 +1033,7 @@ class Hub(LoggingFactory): except: content = error.wrap_exception() break - self.session.send(self.clientele, "result_reply", content=content, + self.session.send(self.query, "result_reply", content=content, parent=msg, ident=client_id, buffers=buffers)