diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index cb2d75b..673f82f 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -4,6 +4,8 @@ import time import threading +from pprint import pprint + from functools import wraps from IPython.external.decorator import decorator @@ -46,7 +48,9 @@ def defaultblock(f, self, *args, **kwargs): self.block = saveblock return ret - +class AbortedTask(object): + def __init__(self, msg_id): + self.msg_id = msg_id # @decorator # def checktargets(f): # @wraps(f) @@ -101,7 +105,11 @@ class Client(object): execution methods: apply/apply_bound/apply_to legacy: execute, run - control methods: queue_status, get_result + query methods: queue_status, get_result + + control methods: abort, kill + + """ @@ -109,7 +117,8 @@ class Client(object): _connected=False _engines=None registration_socket=None - controller_socket=None + query_socket=None + control_socket=None notification_socket=None queue_socket=None task_socket=None @@ -117,8 +126,9 @@ class Client(object): outstanding=None results = None history = None + debug = False - def __init__(self, addr, context=None, username=None): + def __init__(self, addr, context=None, username=None, debug=False): if context is None: context = zmq.Context() self.context = context @@ -135,6 +145,8 @@ class Client(object): self.outstanding=set() self.results = {} self.history = [] + self.debug = debug + self.session.debug = debug self._connect() self._notification_handlers = {'registration_notification' : self._register_engine, @@ -152,7 +164,7 @@ class Client(object): def _update_engines(self, engines): for k,v in engines.iteritems(): eid = int(k) - self._engines[eid] = v + self._engines[eid] = bytes(v) # force not unicode self._ids.add(eid) def _build_targets(self, targets): @@ -173,7 +185,9 @@ class Client(object): return self._connected=True self.session.send(self.registration_socket, 'connection_request') - msg = self.session.recv(self.registration_socket,mode=0)[-1] + idents,msg = self.session.recv(self.registration_socket,mode=0) + if self.debug: + pprint(msg) msg = ss.Message(msg) content = msg.content if content.status == 'ok': @@ -189,10 +203,14 @@ class Client(object): self.notification_socket = self.context.socket(zmq.SUB) self.notification_socket.connect(content.notification) self.notification_socket.setsockopt(zmq.SUBSCRIBE, "") - if content.controller: - self.controller_socket = self.context.socket(zmq.PAIR) - self.controller_socket.setsockopt(zmq.IDENTITY, self.session.session) - self.controller_socket.connect(content.controller) + if content.query: + self.query_socket = self.context.socket(zmq.PAIR) + self.query_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.query_socket.connect(content.query) + if content.control: + self.control_socket = self.context.socket(zmq.PAIR) + self.control_socket.setsockopt(zmq.IDENTITY, self.session.session) + self.control_socket.connect(content.control) self._update_engines(dict(content.engines)) else: @@ -226,7 +244,7 @@ class Client(object): self.results[msg_id] = ss.unwrap_exception(msg['content']) def _handle_apply_reply(self, msg): - # print msg + # pprint(msg) # msg_id = msg['msg_id'] parent = msg['parent_header'] msg_id = parent['msg_id'] @@ -237,14 +255,19 @@ class Client(object): content = msg['content'] if content['status'] == 'ok': self.results[msg_id] = ss.unserialize_object(msg['buffers']) + elif content['status'] == 'aborted': + self.results[msg_id] = AbortedTask(msg_id) + elif content['status'] == 'resubmitted': + pass # handle resubmission else: - self.results[msg_id] = ss.unwrap_exception(content) def _flush_notifications(self): "flush incoming notifications of engine registrations" msg = self.session.recv(self.notification_socket, mode=zmq.NOBLOCK) while msg is not None: + if self.debug: + pprint(msg) msg = msg[-1] msg_type = msg['msg_type'] handler = self._notification_handlers.get(msg_type, None) @@ -258,6 +281,8 @@ class Client(object): "flush incoming task or queue results" msg = self.session.recv(sock, mode=zmq.NOBLOCK) while msg is not None: + if self.debug: + pprint(msg) msg = msg[-1] msg_type = msg['msg_type'] handler = self._queue_handlers.get(msg_type, None) @@ -267,6 +292,14 @@ class Client(object): handler(msg) msg = self.session.recv(sock, mode=zmq.NOBLOCK) + def _flush_control(self, sock): + "flush incoming control replies" + msg = self.session.recv(sock, mode=zmq.NOBLOCK) + while msg is not None: + if self.debug: + pprint(msg) + msg = self.session.recv(sock, mode=zmq.NOBLOCK) + ###### get/setitem ######## def __getitem__(self, key): @@ -297,6 +330,8 @@ class Client(object): self._flush_results(self.queue_socket) if self.task_socket: self._flush_results(self.task_socket) + if self.control_socket: + self._flush_control(self.control_socket) @spinfirst def queue_status(self, targets=None, verbose=False): @@ -308,25 +343,79 @@ class Client(object): the engines on which to execute default : all verbose : bool - whether to return + whether to return lengths only, or lists of ids for each element """ targets = self._build_targets(targets)[1] content = dict(targets=targets) - self.session.send(self.controller_socket, "queue_request", content=content) - idents,msg = self.session.recv(self.controller_socket, 0) + self.session.send(self.query_socket, "queue_request", content=content) + idents,msg = self.session.recv(self.query_socket, 0) + if self.debug: + pprint(msg) return msg['content'] @spinfirst - def clear(self, targets=None): + @defaultblock + def clear(self, targets=None, block=None): """clear the namespace in target(s)""" - pass + targets = self._build_targets(targets)[0] + print targets + for t in targets: + self.session.send(self.control_socket, 'clear_request', content={},ident=t) + error = False + if self.block: + for i in range(len(targets)): + idents,msg = self.session.recv(self.control_socket,0) + if self.debug: + pprint(msg) + if msg['content']['status'] != 'ok': + error = msg['content'] + if error: + return error + @spinfirst - def abort(self, targets=None): + @defaultblock + def abort(self, msg_ids = None, targets=None, block=None): """abort the Queues of target(s)""" - pass + targets = self._build_targets(targets)[0] + print targets + if isinstance(msg_ids, basestring): + msg_ids = [msg_ids] + content = dict(msg_ids=msg_ids) + for t in targets: + self.session.send(self.control_socket, 'abort_request', + content=content, ident=t) + error = False + if self.block: + for i in range(len(targets)): + idents,msg = self.session.recv(self.control_socket,0) + if self.debug: + pprint(msg) + if msg['content']['status'] != 'ok': + error = msg['content'] + if error: + return error + @spinfirst + @defaultblock + def kill(self, targets=None, block=None): + """Terminates one or more engine processes.""" + targets = self._build_targets(targets)[0] + print targets + for t in targets: + self.session.send(self.control_socket, 'kill_request', content={},ident=t) + error = False + if self.block: + for i in range(len(targets)): + idents,msg = self.session.recv(self.control_socket,0) + if self.debug: + pprint(msg) + if msg['content']['status'] != 'ok': + error = msg['content'] + if error: + return error + @defaultblock def execute(self, code, targets='all', block=None): """executes `code` on `targets` in blocking or nonblocking manner. @@ -363,22 +452,6 @@ class Client(object): """ result = self.apply(execute, (code,), targets=None, block=block, bound=False) return result - - # a = time.time() - # content = dict(code=code) - # b = time.time() - # msg = self.session.send(self.task_socket, 'execute_request', - # content=content) - # c = time.time() - # msg_id = msg['msg_id'] - # self.outstanding.add(msg_id) - # self.history.append(msg_id) - # d = time.time() - # if block: - # self.barrier(msg_id) - # return self.results[msg_id] - # else: - # return msg_id def _apply_balanced(self, f, args, kwargs, bound=True, block=None): """the underlying method for applying functions in a load balanced @@ -402,7 +475,7 @@ class Client(object): """Then underlying method for applying functions to specific engines.""" block = block if block is not None else self.block queues,targets = self._build_targets(targets) - + print queues bufs = ss.pack_apply_message(f,args,kwargs) content = dict(bound=bound) msg_ids = [] @@ -438,51 +511,16 @@ class Client(object): """ args = args if args is not None else [] kwargs = kwargs if kwargs is not None else {} + if not isinstance(args, (tuple, list)): + raise TypeError("args must be tuple or list, not %s"%type(args)) + if not isinstance(kwargs, dict): + raise TypeError("kwargs must be dict, not %s"%type(kwargs)) if targets is None: return self._apply_balanced(f,args,kwargs,bound=bound, block=block) else: return self._apply_direct(f, args, kwargs, bound=bound,block=block, targets=targets) - # def apply_bound(self, f, *args, **kwargs): - # """calls f(*args, **kwargs) on a remote engine. This does get - # executed in an engine's namespace. The controller selects the - # target engine via 0MQ XREQ load balancing. - # - # if self.block is False: - # returns msg_id - # else: - # returns actual result of f(*args, **kwargs) - # """ - # return self._apply(f, args, kwargs, bound=True) - # - # - # def apply_to(self, targets, f, *args, **kwargs): - # """calls f(*args, **kwargs) on a specific engine. - # - # if self.block is False: - # returns msg_id - # else: - # returns actual result of f(*args, **kwargs) - # - # The target's namespace is not used here. - # Use apply_bound_to() to access target's globals. - # """ - # return self._apply_to(False, targets, f, args, kwargs) - # - # def apply_bound_to(self, targets, f, *args, **kwargs): - # """calls f(*args, **kwargs) on a specific engine. - # - # if self.block is False: - # returns msg_id - # else: - # returns actual result of f(*args, **kwargs) - # - # This method has access to the target's globals - # - # """ - # return self._apply_to(f, args, kwargs) - # def push(self, ns, targets=None, block=None): """push the contents of `ns` into the namespace on `target`""" if not isinstance(ns, dict): @@ -546,9 +584,11 @@ class Client(object): theids.append(msg_id) content = dict(msg_ids=theids, status_only=status_only) - msg = self.session.send(self.controller_socket, "result_request", content=content) - zmq.select([self.controller_socket], [], []) - idents,msg = self.session.recv(self.controller_socket, zmq.NOBLOCK) + msg = self.session.send(self.query_socket, "result_request", content=content) + zmq.select([self.query_socket], [], []) + idents,msg = self.session.recv(self.query_socket, zmq.NOBLOCK) + if self.debug: + pprint(msg) # while True: # try: diff --git a/IPython/zmq/parallel/controller.py b/IPython/zmq/parallel/controller.py index 7c93f32..b2635a2 100644 --- a/IPython/zmq/parallel/controller.py +++ b/IPython/zmq/parallel/controller.py @@ -297,6 +297,8 @@ class Controller(object): self.save_task_result(idents, msg) elif switch == 'tracktask': self.save_task_destination(idents, msg) + elif switch in ('incontrol', 'outcontrol'): + pass else: logger.error("Invalid message topic: %s"%switch) diff --git a/IPython/zmq/parallel/engine.py b/IPython/zmq/parallel/engine.py index 2a6b781..320fadb 100644 --- a/IPython/zmq/parallel/engine.py +++ b/IPython/zmq/parallel/engine.py @@ -7,6 +7,7 @@ import sys import time import traceback import uuid +from pprint import pprint import zmq from zmq.eventloop import ioloop, zmqstream @@ -20,7 +21,7 @@ import heartmonitor def printer(*msg): - print msg + pprint(msg) class Engine(object): """IPython engine""" @@ -29,26 +30,23 @@ class Engine(object): context=None loop=None session=None - queue_id=None - control_id=None - heart_id=None + ident=None registrar=None heart=None kernel=None - def __init__(self, context, loop, session, registrar, client, queue_id=None, heart_id=None): + def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None): self.context = context self.loop = loop self.session = session self.registrar = registrar self.client = client - self.queue_id = queue_id or str(uuid.uuid4()) - self.heart_id = heart_id or self.queue_id + self.ident = ident if ident else str(uuid.uuid4()) self.registrar.on_send(printer) def register(self): - content = dict(queue=self.queue_id, heartbeat=self.heart_id) + content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) self.registrar.on_recv(self.complete_registration) self.session.send(self.registrar, "registration_request",content=content) @@ -61,14 +59,14 @@ class Engine(object): queue_addr = msg.content.queue if queue_addr: queue = self.context.socket(zmq.PAIR) - queue.setsockopt(zmq.IDENTITY, self.queue_id) + queue.setsockopt(zmq.IDENTITY, self.ident) queue.connect(str(queue_addr)) self.queue = zmqstream.ZMQStream(queue, self.loop) control_addr = msg.content.control if control_addr: control = self.context.socket(zmq.PAIR) - control.setsockopt(zmq.IDENTITY, self.queue_id) + control.setsockopt(zmq.IDENTITY, self.ident) control.connect(str(control_addr)) self.control = zmqstream.ZMQStream(control, self.loop) @@ -81,14 +79,14 @@ class Engine(object): self.task_stream = zmqstream.ZMQStream(task, self.loop) # TaskThread: # mon_addr = msg.content.monitor - # task = taskthread.TaskThread(zmq.PAIR, zmq.PUB, self.queue_id) + # task = taskthread.TaskThread(zmq.PAIR, zmq.PUB, self.ident) # task.connect_in(str(task_addr)) # task.connect_out(str(mon_addr)) # self.task_stream = taskthread.QueueStream(*task.queues) # task.start() hbs = msg.content.heartbeat - self.heart = heartmonitor.Heart(*map(str, hbs), heart_id=self.heart_id) + self.heart = heartmonitor.Heart(*map(str, hbs), heart_id=self.ident) self.heart.start() # ioloop.DelayedCallback(self.heart.start, 1000, self.loop).start() # placeholder for now: diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py index 25b8d9b..2c2c6f3 100755 --- a/IPython/zmq/parallel/streamkernel.py +++ b/IPython/zmq/parallel/streamkernel.py @@ -4,10 +4,12 @@ Kernel adapted from kernel.py to use ZMQ Streams """ import __builtin__ +import os import sys import time import traceback from signal import SIGTERM, SIGKILL +from pprint import pprint from code import CommandCompiler @@ -18,6 +20,9 @@ from streamsession import StreamSession, Message, extract_header, serialize_obje unpack_apply_message from IPython.zmq.completer import KernelCompleter +def printer(*args): + pprint(args) + class OutStream(object): """A file like object that publishes the stream to a 0MQ PUB socket.""" @@ -133,6 +138,7 @@ class Kernel(object): task_stream=None, client=None): self.session = session self.control_stream = control_stream + self.control_socket = control_stream.socket self.reply_stream = reply_stream self.task_stream = task_stream self.pub_stream = pub_stream @@ -153,6 +159,10 @@ class Kernel(object): self.control_handlers[msg_type] = getattr(self, msg_type) #-------------------- control handlers ----------------------------- + def abort_queues(self): + for stream in (self.task_stream, self.reply_stream): + if stream: + self.abort_queue(stream) def abort_queue(self, stream): while True: @@ -186,28 +196,30 @@ class Kernel(object): time.sleep(0.05) def abort_request(self, stream, ident, parent): + """abort a specifig msg by id""" msg_ids = parent['content'].get('msg_ids', None) + if isinstance(msg_ids, basestring): + msg_ids = [msg_ids] if not msg_ids: - self.abort_queue(self.task_stream) - self.abort_queue(self.reply_stream) + self.abort_queues() for mid in msg_ids: - self.aborted.add(mid) + self.aborted.add(str(mid)) content = dict(status='ok') - self.session.send(stream, 'abort_reply', content=content, parent=parent, + reply_msg = self.session.send(stream, 'abort_reply', content=content, parent=parent, ident=ident) + print>>sys.__stdout__, Message(reply_msg) def kill_request(self, stream, idents, parent): - self.abort_queue(self.reply_stream) - if self.task_stream: - self.abort_queue(self.task_stream) + """kill ourselves. This should really be handled in an external process""" + self.abort_queues() msg = self.session.send(stream, 'kill_reply', ident=idents, parent=parent, content = dict(status='ok')) # we can know that a message is done if we *don't* use streams, but # use a socket directly with MessageTracker - time.sleep(1) + time.sleep(.5) os.kill(os.getpid(), SIGTERM) - time.sleep(.25) + time.sleep(1) os.kill(os.getpid(), SIGKILL) def dispatch_control(self, msg): @@ -221,7 +233,7 @@ class Kernel(object): if handler is None: print >> sys.__stderr__, "UNKNOWN CONTROL MESSAGE TYPE:", msg else: - handler(stream, idents, msg) + handler(self.control_stream, idents, msg) def flush_control(self): while any(zmq.select([self.control_socket],[],[],1e-4)): @@ -258,6 +270,16 @@ class Kernel(object): return True + def check_aborted(self, msg_id): + return msg_id in self.aborted + + def unmet_dependencies(self, stream, idents, msg): + reply_type = msg['msg_type'].split('_')[0] + '_reply' + content = dict(status='resubmitted', reason='unmet dependencies') + reply_msg = self.session.send(stream, reply_type, + content=content, parent=msg, ident=idents) + ### TODO: actually resubmit it ### + #-------------------- queue handlers ----------------------------- def execute_request(self, stream, ident, parent): @@ -297,7 +319,7 @@ class Kernel(object): reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent, ident=ident) # print>>sys.__stdout__, Message(reply_msg) if reply_msg['content']['status'] == u'error': - self.abort_queue() + self.abort_queues() def complete_request(self, stream, ident, parent): matches = {'matches' : self.complete(parent), @@ -334,7 +356,7 @@ class Kernel(object): else: working = dict() - suffix = prefix = "" + suffix = prefix = "_" # prevent keyword collisions with lambda f,args,kwargs = unpack_apply_message(bufs, working, copy=False) # if f.fun fname = prefix+f.func_name.strip('<>')+suffix @@ -379,7 +401,7 @@ class Kernel(object): reply_msg = self.session.send(stream, u'apply_reply', reply_content, parent=parent, ident=ident,buffers=result_buf) # print>>sys.__stdout__, Message(reply_msg) if reply_msg['content']['status'] == u'error': - self.abort_queue() + self.abort_queues() def dispatch_queue(self, stream, msg): self.flush_control() @@ -389,12 +411,15 @@ class Kernel(object): header = msg['header'] msg_id = header['msg_id'] dependencies = header.get('dependencies', []) - if self.check_aborted(msg_id): - return self.abort_reply(stream, msg) + self.aborted.remove(msg_id) + # is it safe to assume a msg_id will not be resubmitted? + reply_type = msg['msg_type'].split('_')[0] + '_reply' + reply_msg = self.session.send(stream, reply_type, + content={'status' : 'aborted'}, parent=msg, ident=idents) + return if not self.check_dependencies(dependencies): - return self.unmet_dependencies(stream, msg) - + return self.unmet_dependencies(stream, idents, msg) handler = self.queue_handlers.get(msg['msg_type'], None) if handler is None: print >> sys.__stderr__, "UNKNOWN MESSAGE TYPE:", msg @@ -405,12 +430,15 @@ class Kernel(object): #### stream mode: if self.control_stream: self.control_stream.on_recv(self.dispatch_control, copy=False) + self.control_stream.on_err(printer) if self.reply_stream: self.reply_stream.on_recv(lambda msg: self.dispatch_queue(self.reply_stream, msg), copy=False) + self.reply_stream.on_err(printer) if self.task_stream: self.task_stream.on_recv(lambda msg: self.dispatch_queue(self.task_stream, msg), copy=False) + self.task_stream.on_err(printer) #### while True mode: # while True: diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py index 710a605..c01e300 100644 --- a/IPython/zmq/parallel/streamsession.py +++ b/IPython/zmq/parallel/streamsession.py @@ -257,7 +257,7 @@ def unpack_apply_message(bufs, g=None, copy=True): class StreamSession(object): """tweaked version of IPython.zmq.session.Session, for development in Parallel""" - + debug=False def __init__(self, username=None, session=None, packer=None, unpacker=None): if username is None: username = os.environ.get('USER','username') @@ -335,6 +335,10 @@ class StreamSession(object): if buffers: stream.send(buffers[-1], copy=False) omsg = Message(msg) + if self.debug: + pprint.pprint(omsg) + pprint.pprint(to_send) + pprint.pprint(buffers) return omsg def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index b6d23ec..382f4d1 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -103,7 +103,7 @@ class View(object): This method has access to the targets' globals """ - return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True) + return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True) class DirectView(View): @@ -129,12 +129,28 @@ class DirectView(View): def __setitem__(self,key,value): self.update({key:value}) - def clear(self): - """clear the remote namespace""" - return self.client.clear(targets=self.targets,block=self.block) + def clear(self, block=False): + """Clear the remote namespaces on my engines.""" + block = block if block is not None else self.block + return self.client.clear(targets=self.targets,block=block) + + def kill(self, block=True): + """Kill my engines.""" + block = block if block is not None else self.block + return self.client.kill(targets=self.targets,block=block) - def abort(self): - return self.client.abort(targets=self.targets,block=self.block) + def abort(self, msg_ids=None, block=None): + """Abort jobs on my engines. + + Parameters + ---------- + + msg_ids : None, str, list of strs, optional + if None: abort all jobs. + else: abort specific msg_id(s). + """ + block = block if block is not None else self.block + return self.client.abort(msg_ids=msg_ids, targets=self.targets, block=block) class LoadBalancedView(View): _targets=None diff --git a/examples/zmqontroller/controller.py b/examples/zmqontroller/controller.py index 05064f7..4f95537 100644 --- a/examples/zmqontroller/controller.py +++ b/examples/zmqontroller/controller.py @@ -124,7 +124,7 @@ def setup(): client_addrs = { 'control' : "%s:%i"%(iface, ccport), - 'controller': "%s:%i"%(iface, cport), + 'query': "%s:%i"%(iface, cport), 'queue': "%s:%i"%(iface, cqport), 'task' : "%s:%i"%(iface, ctport), 'notification': "%s:%i"%(iface, nport)