From 23d190659dd7c2ed4db834054b835d31f92a0932 2011-04-08 00:38:17
From: MinRK
Date: 2011-04-08 00:38:17
Subject: [PATCH] add timeout for unmet dependencies in task scheduler

---

diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py
index 887fb12..d92dba8 100644
--- a/IPython/zmq/parallel/asyncresult.py
+++ b/IPython/zmq/parallel/asyncresult.py
@@ -36,7 +36,7 @@ class AsyncResult(object):
         self._fname=fname
         self._ready = False
         self._success = None
-        self._flatten_result = len(msg_ids) == 1
+        self._single_result = len(msg_ids) == 1
     
     def __repr__(self):
         if self._ready:
@@ -50,7 +50,7 @@ class AsyncResult(object):
         Override me in subclasses for turning a list of results
         into the expected form.
         """
-        if self._flatten_result:
+        if self._single_result:
             return res[0]
         else:
             return res
@@ -90,7 +90,12 @@ class AsyncResult(object):
             try:
                 results = map(self._client.results.get, self.msg_ids)
                 self._result = results
-                results = error.collect_exceptions(results, self._fname)
+                if self._single_result:
+                    r = results[0]
+                    if isinstance(r, Exception):
+                        raise r
+                else:
+                    results = error.collect_exceptions(results, self._fname)
                 self._result = self._reconstruct_result(results)
             except Exception, e:
                 self._exception = e
@@ -138,7 +143,7 @@ class AsyncResult(object):
     @check_ready
     def metadata(self):
         """metadata property."""
-        if self._flatten_result:
+        if self._single_result:
             return self._metadata[0]
         else:
             return self._metadata
@@ -165,7 +170,7 @@ class AsyncResult(object):
             return error.collect_exceptions(self._result[key], self._fname)
         elif isinstance(key, basestring):
             values = [ md[key] for md in self._metadata ]
-            if self._flatten_result:
+            if self._single_result:
                 return values[0]
             else:
                 return values
@@ -190,7 +195,7 @@ class AsyncMapResult(AsyncResult):
     def __init__(self, client, msg_ids, mapObject, fname=''):
         AsyncResult.__init__(self, client, msg_ids, fname=fname)
         self._mapObject = mapObject
-        self._flatten_result = False
+        self._single_result = False
     
     def _reconstruct_result(self, res):
         """Perform the gather on the actual results."""

diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py
index 82baec9..3f85555 100644
--- a/IPython/zmq/parallel/client.py
+++ b/IPython/zmq/parallel/client.py
@@ -765,9 +765,26 @@ class Client(object):
             raise result
         return result
-    
+
+    def _build_dependency(self, dep):
+        """helper for building jsonable dependencies from various input forms"""
+        if isinstance(dep, Dependency):
+            return dep.as_dict()
+        elif isinstance(dep, AsyncResult):
+            return dep.msg_ids
+        elif dep is None:
+            return []
+        elif isinstance(dep, set):
+            return list(dep)
+        elif isinstance(dep, (list,dict)):
+            return dep
+        elif isinstance(dep, basestring):
+            return [dep]
+        else:
+            raise TypeError("Dependency may be: set, list, dict, Dependency or AsyncResult, not %r"%type(dep))
+
     def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
-                after=None, follow=None):
+                after=None, follow=None, timeout=None):
         """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
 
         This is the central execution command for the client.
@@ -817,6 +834,10 @@ class Client(object):
             This job will only be run on an engine where this dependency is met.
+        timeout : float or None
+            Only for load-balanced execution (targets=None)
+            Specify an amount of time (in seconds) for the scheduler to wait
+            for dependencies to be met before failing with a DependencyTimeout.
         
         Returns
         -------
         if block is False:
@@ -844,33 +865,23 @@ class Client(object):
             raise TypeError("args must be tuple or list, not %s"%type(args))
         if not isinstance(kwargs, dict):
             raise TypeError("kwargs must be dict, not %s"%type(kwargs))
-        
-        if isinstance(after, Dependency):
-            after = after.as_dict()
-        elif isinstance(after, AsyncResult):
-            after=after.msg_ids
-        elif after is None:
-            after = []
-        if isinstance(follow, Dependency):
-            # if len(follow) > 1 and follow.mode == 'all':
-            #     warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
-            follow = follow.as_dict()
-        elif isinstance(follow, AsyncResult):
-            follow=follow.msg_ids
-        elif follow is None:
-            follow = []
-        options = dict(bound=bound, block=block, after=after, follow=follow)
+
+        after = self._build_dependency(after)
+        follow = self._build_dependency(follow)
+
+        options = dict(bound=bound, block=block)
         if targets is None:
-            return self._apply_balanced(f, args, kwargs, **options)
+            return self._apply_balanced(f, args, kwargs, timeout=timeout,
+                                        after=after, follow=follow, **options)
         else:
             return self._apply_direct(f, args, kwargs, targets=targets, **options)
     
     def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
-                        after=None, follow=None):
+                        after=None, follow=None, timeout=None):
         """The underlying method for applying functions in a load balanced
         manner, via the task queue."""
-        subheader = dict(after=after, follow=follow)
+        subheader = dict(after=after, follow=follow, timeout=timeout)
         bufs = ss.pack_apply_message(f,args,kwargs)
         content = dict(bound=bound)
         
@@ -885,8 +896,7 @@ class Client(object):
         else:
             return ar
 
-    def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
-                        after=None, follow=None):
+    def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
         """The underlying method for applying functions to specific engines
         via the MUX queue."""

diff --git a/IPython/zmq/parallel/controller.py b/IPython/zmq/parallel/controller.py
index 6365c25..26b7ea5 100755
--- a/IPython/zmq/parallel/controller.py
+++ b/IPython/zmq/parallel/controller.py
@@ -100,9 +100,9 @@ class ControllerFactory(HubFactory):
             self.log.warn("task::using no Task scheduler")
         else:
-            self.log.warn("task::using Python %s Task scheduler"%self.scheme)
+            self.log.info("task::using Python %s Task scheduler"%self.scheme)
             sargs = (self.client_addrs['task'], self.engine_addrs['task'], self.monitor_url, self.client_addrs['notification'])
-            q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme))
+            q = Process(target=launch_scheduler, args=sargs, kwargs = dict(scheme=self.scheme, logname=self.log.name, loglevel=self.log.level))
             q.daemon=True
             children.append(q)

diff --git a/IPython/zmq/parallel/dependency.py b/IPython/zmq/parallel/dependency.py
index 3915a40..7f78097 100644
--- a/IPython/zmq/parallel/dependency.py
+++ b/IPython/zmq/parallel/dependency.py
@@ -55,7 +55,7 @@ def require(*names):
     return depend(_require, *names)
 
 class Dependency(set):
-    """An object for representing a set of dependencies.
+    """An object for representing a set of msg_id dependencies.
     
     Subclassed from set()."""
 

diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py
index e467d5d..a52b512 100644
--- a/IPython/zmq/parallel/error.py
+++ b/IPython/zmq/parallel/error.py
@@ -154,6 +154,9 @@ class UnmetDependency(KernelError):
 class ImpossibleDependency(UnmetDependency):
     pass
 
+class DependencyTimeout(UnmetDependency):
+    pass
+
 class RemoteError(KernelError):
     """Error raised elsewhere"""
     ename=None

diff --git a/IPython/zmq/parallel/scheduler.py b/IPython/zmq/parallel/scheduler.py
index 7dbec99..216c512 100644
--- a/IPython/zmq/parallel/scheduler.py
+++ b/IPython/zmq/parallel/scheduler.py
@@ -12,9 +12,9 @@ Python Scheduler exists.
 from __future__ import print_function
 import sys
 import logging
-from random import randint,random
+from random import randint, random
 from types import FunctionType
-
+from datetime import datetime, timedelta
 try:
     import numpy
 except ImportError:
@@ -29,11 +29,11 @@ from IPython.external.decorator import decorator
 from IPython.utils.traitlets import Instance, Dict, List, Set
 
 import error
-from client import Client
+# from client import Client
 from dependency import Dependency
 import streamsession as ss
 from entry_point import connect_logger, local_logger
-from factory import LoggingFactory
+from factory import SessionFactory
 
 
 @decorator
@@ -110,7 +110,7 @@ def leastload(loads):
 # store empty default dependency:
 MET = Dependency([])
 
-class TaskScheduler(LoggingFactory):
+class TaskScheduler(SessionFactory):
     """Python TaskScheduler object.
     
     This is the simplest object that supports msg_id based
@@ -125,7 +125,6 @@ class TaskScheduler(LoggingFactory):
     engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
     notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
     mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
-    io_loop = Instance(ioloop.IOLoop)
     
     # internals:
     dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
@@ -141,20 +140,18 @@ class TaskScheduler(LoggingFactory):
     all_failed = Set() # set of all failed tasks
     all_done = Set() # set of all finished tasks=union(completed,failed)
     blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
-    session = Instance(ss.StreamSession)
+    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
     
-    def __init__(self, **kwargs):
-        super(TaskScheduler, self).__init__(**kwargs)
-        
-        self.session = ss.StreamSession(username="TaskScheduler")
-        
+    def start(self):
         self.engine_stream.on_recv(self.dispatch_result, copy=False)
         self._notification_handlers = dict(
             registration_notification = self._register_engine,
             unregistration_notification = self._unregister_engine
         )
         self.notifier_stream.on_recv(self.dispatch_notification)
+        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 1e3, self.loop) # 1 Hz
+        self.auditor.start()
         self.log.info("Scheduler started...%r"%self)
     
     def resume_receiving(self):
@@ -261,37 +258,55 @@ class TaskScheduler(LoggingFactory):
         # location dependencies
         follow = Dependency(header.get('follow', []))
         
-        # check if unreachable:
         if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
-            self.depending[msg_id] = [raw_msg,MET,MET]
+            self.depending[msg_id] = [raw_msg,MET,MET,None]
             return self.fail_unreachable(msg_id)
         
+        # turn timeouts into datetime objects:
+        timeout = header.get('timeout', None)
+        if timeout:
+            timeout = datetime.now() + timedelta(seconds=timeout)
+        
         if after.check(self.all_completed, self.all_failed):
             # time deps already met, try to run
             if not self.maybe_run(msg_id, raw_msg, follow):
                 # can't run yet
-                self.save_unmet(msg_id, raw_msg, after, follow)
+                self.save_unmet(msg_id, raw_msg, after, follow, timeout)
         else:
-            self.save_unmet(msg_id, raw_msg, after, follow)
+            self.save_unmet(msg_id, raw_msg, after, follow, timeout)
     
     @logged
-    def fail_unreachable(self, msg_id):
+    def audit_timeouts(self):
+        """Audit all waiting tasks for expired timeouts."""
+        now = datetime.now()
+        for msg_id in self.depending.keys():
+            # must recheck, in case one failure cascaded to another:
+            if msg_id in self.depending:
+                raw,after,follow,timeout = self.depending[msg_id]
+                if timeout and timeout < now:
+                    self.fail_unreachable(msg_id, timeout=True)
+    
+    @logged
+    def fail_unreachable(self, msg_id, timeout=False):
         """a message has become unreachable"""
         if msg_id not in self.depending:
             self.log.error("msg %r already failed!"%msg_id)
             return
-        raw_msg, after, follow = self.depending.pop(msg_id)
+        raw_msg, after, follow, _ = self.depending.pop(msg_id) # discard the stored deadline; it must not shadow the timeout flag
         for mid in follow.union(after):
             if mid in self.dependencies:
                 self.dependencies[mid].remove(msg_id)
         
+        # FIXME: unpacking a message I've already unpacked, but didn't save:
         idents,msg = self.session.feed_identities(raw_msg, copy=False)
         msg = self.session.unpack_message(msg, copy=False, content=False)
         header = msg['header']
         
+        impossible = error.DependencyTimeout if timeout else error.ImpossibleDependency
+        
         try:
-            raise error.ImpossibleDependency()
+            raise impossible()
         except:
             content = ss.wrap_exception()
@@ -334,9 +349,9 @@ class TaskScheduler(LoggingFactory):
         return True
     
     @logged
-    def save_unmet(self, msg_id, raw_msg, after, follow):
+    def save_unmet(self, msg_id, raw_msg, after, follow, timeout):
         """Save a message for later submission when its dependencies are met."""
-        self.depending[msg_id] = [raw_msg,after,follow]
+        self.depending[msg_id] = [raw_msg,after,follow,timeout]
         # track the ids in follow or after, but not those already finished
         for dep_id in after.union(follow).difference(self.all_done):
             if dep_id not in self.dependencies:
@@ -413,10 +428,10 @@ class TaskScheduler(LoggingFactory):
         if msg_id not in self.blacklist:
             self.blacklist[msg_id] = set()
         self.blacklist[msg_id].add(engine)
-        raw_msg,follow = self.pending[engine].pop(msg_id)
+        raw_msg,follow,timeout = self.pending[engine].pop(msg_id)
         if not self.maybe_run(msg_id, raw_msg, follow):
             # resubmit failed, put it back in our dependency tree
-            self.save_unmet(msg_id, raw_msg, MET, follow)
+            self.save_unmet(msg_id, raw_msg, MET, follow, timeout)
         pass
     
     @logged
@@ -435,7 +450,7 @@ class TaskScheduler(LoggingFactory):
         jobs = self.dependencies.pop(dep_id)
         
         for msg_id in jobs:
-            raw_msg, after, follow = self.depending[msg_id]
+            raw_msg, after, follow, timeout = self.depending[msg_id]
             # if dep_id in after:
             #     if after.mode == 'all' and (success or not after.success_only):
             #         after.remove(dep_id)
@@ -497,9 +512,9 @@ def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, logname='ZMQ', log_a
         local_logger(logname, loglevel)
     
     scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
-                            mon_stream=mons,notifier_stream=nots,
-                            scheme=scheme,io_loop=loop, logname=logname)
-    
+                            mon_stream=mons, notifier_stream=nots,
+                            scheme=scheme, loop=loop, logname=logname)
+    scheduler.start()
    try:
        loop.start()
    except KeyboardInterrupt:
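
Note from review (not part of the patch): a rough usage sketch of the new
`timeout` argument. The connection setup and the exact exception type seen
on the client side are assumptions here; the scheduler raises
DependencyTimeout remotely, and the client will likely surface it as a
RemoteError whose ename is 'DependencyTimeout'. Both subclass KernelError,
so that is what the sketch catches.

    from IPython.zmq.parallel.client import Client
    from IPython.zmq.parallel import error

    c = Client()  # assumes a locally running controller with default addresses

    # Submit a load-balanced task (targets=None) that depends on a msg_id
    # which will never complete.  With timeout=5, the scheduler's 1 Hz
    # auditor should fail the task after roughly five seconds instead of
    # leaving it queued forever.
    ar = c.apply(lambda: 'hello', block=False, targets=None,
                 after=['deadbeef'], timeout=5)  # 'deadbeef': fake msg_id
    try:
        ar.get()
    except error.KernelError, e:
        print e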