From d51586b9eda97ab57c65fa0be2aa7756f173ec15 2011-04-08 00:38:16
From: MinRK
Date: 2011-04-08 00:38:16
Subject: [PATCH] Improvements to dependency handling

Specifically:

* add 'success_only' switch to Dependencies
* Scheduler handles some cases where Dependencies are impossible to meet.

---

diff --git a/IPython/utils/pickleutil.py b/IPython/utils/pickleutil.py
index b3b09ee..3314e9e 100644
--- a/IPython/utils/pickleutil.py
+++ b/IPython/utils/pickleutil.py
@@ -16,17 +16,23 @@ __docformat__ = "restructuredtext en"
 #-------------------------------------------------------------------------------
 
 from types import FunctionType
+import copy
 
-# contents of codeutil should either be in here, or codeutil belongs in IPython/util
 from IPython.zmq.parallel.dependency import dependent
+
 import codeutil
 
+#-------------------------------------------------------------------------------
+# Classes
+#-------------------------------------------------------------------------------
+
+
 class CannedObject(object):
     def __init__(self, obj, keys=[]):
         self.keys = keys
-        self.obj = obj
+        self.obj = copy.copy(obj)
         for key in keys:
-            setattr(obj, key, can(getattr(obj, key)))
+            setattr(self.obj, key, can(getattr(obj, key)))
     
     def getObject(self, g=None):
@@ -43,6 +49,7 @@ class CannedFunction(CannedObject):
     def __init__(self, f):
         self._checkType(f)
         self.code = f.func_code
+        self.__name__ = f.__name__
     
     def _checkType(self, obj):
         assert isinstance(obj, FunctionType), "Not a function type"
@@ -53,6 +60,11 @@ class CannedFunction(CannedObject):
         newFunc = FunctionType(self.code, g)
         return newFunc
 
+#-------------------------------------------------------------------------------
+# Functions
+#-------------------------------------------------------------------------------
+
+
 def can(obj):
     if isinstance(obj, FunctionType):
         return CannedFunction(obj)
diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py
index efcee41..887fb12 100644
--- a/IPython/zmq/parallel/asyncresult.py
+++ b/IPython/zmq/parallel/asyncresult.py
@@ -36,6 +36,7 @@ class AsyncResult(object):
         self._fname=fname
         self._ready = False
         self._success = None
+        self._flatten_result = len(msg_ids) == 1
     
     def __repr__(self):
         if self._ready:
@@ -49,7 +50,7 @@ class AsyncResult(object):
         """
         Override me in subclasses for turning a list of results
         into the expected form.
""" - if len(self.msg_ids) == 1: + if self._flatten_result: return res[0] else: return res @@ -115,7 +116,7 @@ class AsyncResult(object): def get_dict(self, timeout=-1): """Get the results as a dict, keyed by engine_id.""" results = self.get(timeout) - engine_ids = [md['engine_id'] for md in self._metadata ] + engine_ids = [ md['engine_id'] for md in self._metadata ] bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k)) maxcount = bycount.count(bycount[-1]) if maxcount > 1: @@ -130,11 +131,17 @@ class AsyncResult(object): """result property.""" return self._result + # abbreviated alias: + r = result + @property @check_ready def metadata(self): """metadata property.""" - return self._metadata + if self._flatten_result: + return self._metadata[0] + else: + return self._metadata @property def result_dict(self): @@ -157,7 +164,11 @@ class AsyncResult(object): elif isinstance(key, slice): return error.collect_exceptions(self._result[key], self._fname) elif isinstance(key, basestring): - return [ md[key] for md in self._metadata ] + values = [ md[key] for md in self._metadata ] + if self._flatten_result: + return values[0] + else: + return values else: raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key)) @@ -177,8 +188,9 @@ class AsyncMapResult(AsyncResult): """ def __init__(self, client, msg_ids, mapObject, fname=''): - self._mapObject = mapObject AsyncResult.__init__(self, client, msg_ids, fname=fname) + self._mapObject = mapObject + self._flatten_result = False def _reconstruct_result(self, res): """Perform the gather on the actual results.""" diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index e2c5ed3..82baec9 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -91,7 +91,13 @@ def defaultblock(f, self, *args, **kwargs): #-------------------------------------------------------------------------- class Metadata(dict): - """Subclass of dict for initializing metadata values.""" + """Subclass of dict for initializing metadata values. + + Attribute access works on keys. + + These objects have a strict set of keys - errors will raise if you try + to add new keys. 
+ """ def __init__(self, *args, **kwargs): dict.__init__(self) md = {'msg_id' : None, @@ -113,7 +119,27 @@ class Metadata(dict): } self.update(md) self.update(dict(*args, **kwargs)) + + def __getattr__(self, key): + """getattr aliased to getitem""" + if key in self.iterkeys(): + return self[key] + else: + raise AttributeError(key) + def __setattr__(self, key, value): + """setattr aliased to setitem, with strict""" + if key in self.iterkeys(): + self[key] = value + else: + raise AttributeError(key) + + def __setitem__(self, key, value): + """strict static key enforcement""" + if key in self.iterkeys(): + dict.__setitem__(self, key, value) + else: + raise KeyError(key) class Client(object): @@ -372,16 +398,22 @@ class Client(object): def _extract_metadata(self, header, parent, content): md = {'msg_id' : parent['msg_id'], - 'submitted' : datetime.strptime(parent['date'], ss.ISO8601), - 'started' : datetime.strptime(header['started'], ss.ISO8601), - 'completed' : datetime.strptime(header['date'], ss.ISO8601), 'received' : datetime.now(), - 'engine_uuid' : header['engine'], - 'engine_id' : self._engines.get(header['engine'], None), + 'engine_uuid' : header.get('engine', None), 'follow' : parent['follow'], 'after' : parent['after'], 'status' : content['status'], } + + if md['engine_uuid'] is not None: + md['engine_id'] = self._engines.get(md['engine_uuid'], None) + + if 'date' in parent: + md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601) + if 'started' in header: + md['started'] = datetime.strptime(header['started'], ss.ISO8601) + if 'date' in header: + md['completed'] = datetime.strptime(header['date'], ss.ISO8601) return md def _handle_execute_reply(self, msg): @@ -393,7 +425,10 @@ class Client(object): parent = msg['parent_header'] msg_id = parent['msg_id'] if msg_id not in self.outstanding: - print("got unknown result: %s"%msg_id) + if msg_id in self.history: + print ("got stale result: %s"%msg_id) + else: + print ("got unknown result: %s"%msg_id) else: self.outstanding.remove(msg_id) self.results[msg_id] = ss.unwrap_exception(msg['content']) @@ -403,7 +438,12 @@ class Client(object): parent = msg['parent_header'] msg_id = parent['msg_id'] if msg_id not in self.outstanding: - print ("got unknown result: %s"%msg_id) + if msg_id in self.history: + print ("got stale result: %s"%msg_id) + print self.results[msg_id] + print msg + else: + print ("got unknown result: %s"%msg_id) else: self.outstanding.remove(msg_id) content = msg['content'] @@ -424,9 +464,10 @@ class Client(object): pass else: e = ss.unwrap_exception(content) - e_uuid = e.engine_info['engineid'] - eid = self._engines[e_uuid] - e.engine_info['engineid'] = eid + if e.engine_info: + e_uuid = e.engine_info['engineid'] + eid = self._engines[e_uuid] + e.engine_info['engineid'] = eid self.results[msg_id] = e def _flush_notifications(self): @@ -811,6 +852,8 @@ class Client(object): elif after is None: after = [] if isinstance(follow, Dependency): + # if len(follow) > 1 and follow.mode == 'all': + # warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning) follow = follow.as_dict() elif isinstance(follow, AsyncResult): follow=follow.msg_ids @@ -827,7 +870,6 @@ class Client(object): after=None, follow=None): """The underlying method for applying functions in a load balanced manner, via the task queue.""" - subheader = dict(after=after, follow=follow) bufs = ss.pack_apply_message(f,args,kwargs) content = dict(bound=bound) diff --git a/IPython/zmq/parallel/dependency.py 
index 964a823..3915a40 100644
--- a/IPython/zmq/parallel/dependency.py
+++ b/IPython/zmq/parallel/dependency.py
@@ -1,6 +1,8 @@
 """Dependency utilities"""
 
 from IPython.external.decorator import decorator
+from error import UnmetDependency
+
 
 # flags
 ALL = 1 << 0
@@ -8,9 +10,6 @@ ANY = 1 << 1
 HERE = 1 << 2
 ANYWHERE = 1 << 3
 
-class UnmetDependency(Exception):
-    pass
-
 
 class depend(object):
     """Dependency decorator, for use with tasks."""
@@ -30,7 +29,7 @@ class dependent(object):
     
     def __init__(self, f, df, *dargs, **dkwargs):
         self.f = f
-        self.func_name = self.f.func_name
+        self.func_name = getattr(f, '__name__', 'f')
         self.df = df
         self.dargs = dargs
         self.dkwargs = dkwargs
@@ -39,6 +38,10 @@ class dependent(object):
         if self.df(*self.dargs, **self.dkwargs) is False:
             raise UnmetDependency()
         return self.f(*args, **kwargs)
+    
+    @property
+    def __name__(self):
+        return self.func_name
 
 def _require(*names):
     for name in names:
@@ -57,18 +60,23 @@ class Dependency(set):
     Subclassed from set()."""
     
     mode='all'
+    success_only=True
     
-    def __init__(self, dependencies=[], mode='all'):
+    def __init__(self, dependencies=[], mode='all', success_only=True):
         if isinstance(dependencies, dict):
             # load from dict
-            dependencies = dependencies.get('dependencies', [])
             mode = dependencies.get('mode', mode)
+            success_only = dependencies.get('success_only', success_only)
+            dependencies = dependencies.get('dependencies', [])
         set.__init__(self, dependencies)
         self.mode = mode.lower()
+        self.success_only=success_only
        
         if self.mode not in ('any', 'all'):
             raise NotImplementedError("Only any|all supported, not %r"%mode)
     
-    def check(self, completed):
+    def check(self, completed, failed=None):
+        if failed is not None and not self.success_only:
+            completed = completed.union(failed)
         if len(self) == 0:
             return True
         if self.mode == 'all':
@@ -78,13 +86,25 @@ class Dependency(set):
         else:
             raise NotImplementedError("Only any|all supported, not %r"%mode)
     
+    def unreachable(self, failed):
+        if len(self) == 0 or len(failed) == 0 or not self.success_only:
+            return False
+        if self.mode == 'all':
+            return not self.isdisjoint(failed)
+        elif self.mode == 'any':
+            return self.issubset(failed)
+        else:
+            raise NotImplementedError("Only any|all supported, not %r"%self.mode)
+    
+    
     def as_dict(self):
         """Represent this dependency as a dict.
For json compatibility.""" return dict( dependencies=list(self), - mode=self.mode + mode=self.mode, + success_only=self.success_only, ) -__all__ = ['UnmetDependency', 'depend', 'require', 'Dependency'] +__all__ = ['depend', 'require', 'Dependency'] diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py index add5aac..e467d5d 100644 --- a/IPython/zmq/parallel/error.py +++ b/IPython/zmq/parallel/error.py @@ -148,6 +148,12 @@ class FileTimeoutError(KernelError): class TimeoutError(KernelError): pass +class UnmetDependency(KernelError): + pass + +class ImpossibleDependency(UnmetDependency): + pass + class RemoteError(KernelError): """Error raised elsewhere""" ename=None diff --git a/IPython/zmq/parallel/scheduler.py b/IPython/zmq/parallel/scheduler.py index 6a8f5ac..3275b6f 100644 --- a/IPython/zmq/parallel/scheduler.py +++ b/IPython/zmq/parallel/scheduler.py @@ -27,6 +27,7 @@ from IPython.external.decorator import decorator from IPython.config.configurable import Configurable from IPython.utils.traitlets import Instance, Dict, List, Set +import error from client import Client from dependency import Dependency import streamsession as ss @@ -104,6 +105,9 @@ def leastload(loads): #--------------------------------------------------------------------- # Classes #--------------------------------------------------------------------- +# store empty default dependency: +MET = Dependency([]) + class TaskScheduler(Configurable): """Python TaskScheduler object. @@ -126,10 +130,14 @@ class TaskScheduler(Configurable): depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow) pending = Dict() # dict by engine_uuid of submitted tasks completed = Dict() # dict by engine_uuid of completed tasks + failed = Dict() # dict by engine_uuid of failed tasks + destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) clients = Dict() # dict by msg_id for who submitted the task targets = List() # list of target IDENTs loads = List() # list of engine loads - all_done = Set() # set of all completed tasks + all_completed = Set() # set of all completed tasks + all_failed = Set() # set of all failed tasks + all_done = Set() # set of all finished tasks=union(completed,failed) blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency session = Instance(ss.StreamSession) @@ -182,6 +190,7 @@ class TaskScheduler(Configurable): self.loads.insert(0,0) # initialize sets self.completed[uid] = set() + self.failed[uid] = set() self.pending[uid] = {} if len(self.targets) == 1: self.resume_receiving() @@ -196,6 +205,11 @@ class TaskScheduler(Configurable): self.engine_stream.flush() self.completed.pop(uid) + self.failed.pop(uid) + # don't pop destinations, because it might be used later + # map(self.destinations.pop, self.completed.pop(uid)) + # map(self.destinations.pop, self.failed.pop(uid)) + lost = self.pending.pop(uid) idx = self.targets.index(uid) @@ -235,15 +249,23 @@ class TaskScheduler(Configurable): # time dependencies after = Dependency(header.get('after', [])) if after.mode == 'all': - after.difference_update(self.all_done) - if after.check(self.all_done): + after.difference_update(self.all_completed) + if not after.success_only: + after.difference_update(self.all_failed) + if after.check(self.all_completed, self.all_failed): # recast as empty set, if `after` already met, # to prevent unnecessary set comparisons - after = Dependency([]) + after = MET # location dependencies follow = Dependency(header.get('follow', 
-        if len(after) == 0:
+        
+        # check if unreachable:
+        if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
+            self.depending[msg_id] = [raw_msg,MET,MET]
+            return self.fail_unreachable(msg_id)
+        
+        if after.check(self.all_completed, self.all_failed):
             # time deps already met, try to run
             if not self.maybe_run(msg_id, raw_msg, follow):
                 # can't run yet
@@ -252,6 +274,35 @@ class TaskScheduler(Configurable):
             self.save_unmet(msg_id, raw_msg, after, follow)
     
     @logged
+    def fail_unreachable(self, msg_id):
+        """a message has become unreachable"""
+        if msg_id not in self.depending:
+            logging.error("msg %r already failed!"%msg_id)
+            return
+        raw_msg, after, follow = self.depending.pop(msg_id)
+        for mid in follow.union(after):
+            if mid in self.dependencies:
+                self.dependencies[mid].remove(msg_id)
+        
+        idents,msg = self.session.feed_identities(raw_msg, copy=False)
+        msg = self.session.unpack_message(msg, copy=False, content=False)
+        header = msg['header']
+        
+        try:
+            raise error.ImpossibleDependency()
+        except:
+            content = ss.wrap_exception()
+        
+        self.all_done.add(msg_id)
+        self.all_failed.add(msg_id)
+        
+        msg = self.session.send(self.client_stream, 'apply_reply', content,
+                parent=header, ident=idents)
+        self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
+        
+        self.update_dependencies(msg_id, success=False)
+        
+    @logged
     def maybe_run(self, msg_id, raw_msg, follow=None):
         """check location dependencies, and run if they are met."""
         
@@ -259,10 +310,20 @@ class TaskScheduler(Configurable):
         def can_run(idx):
             target = self.targets[idx]
             return target not in self.blacklist.get(msg_id, []) and\
-                    follow.check(self.completed[target])
+                    follow.check(self.completed[target], self.failed[target])
         
         indices = filter(can_run, range(len(self.targets)))
         if not indices:
+            # TODO evaluate unmeetable follow dependencies
+            if follow.mode == 'all':
+                dests = set()
+                relevant = self.all_completed if follow.success_only else self.all_done
+                for m in follow.intersection(relevant):
+                    dests.add(self.destinations[m])
+                if len(dests) > 1:
+                    self.fail_unreachable(msg_id)
+            
+            
             return False
         else:
             indices = None
@@ -271,10 +332,10 @@ class TaskScheduler(Configurable):
             return True
     
     @logged
-    def save_unmet(self, msg_id, msg, after, follow):
+    def save_unmet(self, msg_id, raw_msg, after, follow):
         """Save a message for later submission when its dependencies are met."""
-        self.depending[msg_id] = (msg_id,msg,after,follow)
-        # track the ids in both follow/after, but not those already completed
+        self.depending[msg_id] = [raw_msg,after,follow]
+        # track the ids in follow or after, but not those already finished
        
         for dep_id in after.union(follow).difference(self.all_done):
             if dep_id not in self.dependencies:
                 self.dependencies[dep_id] = set()
@@ -313,14 +374,15 @@ class TaskScheduler(Configurable):
         msg = self.session.unpack_message(msg, content=False, copy=False)
         header = msg['header']
         if header.get('dependencies_met', True):
-            self.handle_result_success(idents, msg['parent_header'], raw_msg)
-            # send to monitor
+            success = (header['status'] == 'ok')
+            self.handle_result(idents, msg['parent_header'], raw_msg, success)
+            # send to Hub monitor
             self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
         else:
             self.handle_unmet_dependency(idents, msg['parent_header'])
     
     @logged
-    def handle_result_success(self, idents, parent, raw_msg):
+    def handle_result(self, idents, parent, raw_msg, success=True):
         # first, relay result to client
         engine = idents[0]
         client = idents[1]
@@ -331,10 +393,16 @@ class TaskScheduler(Configurable):
         # now, update our data structures
         msg_id = parent['msg_id']
         self.pending[engine].pop(msg_id)
-        self.completed[engine].add(msg_id)
+        if success:
+            self.completed[engine].add(msg_id)
+            self.all_completed.add(msg_id)
+        else:
+            self.failed[engine].add(msg_id)
+            self.all_failed.add(msg_id)
         self.all_done.add(msg_id)
+        self.destinations[msg_id] = engine
         
-        self.update_dependencies(msg_id)
+        self.update_dependencies(msg_id, success)
     
     @logged
     def handle_unmet_dependency(self, idents, parent):
@@ -346,24 +414,39 @@ class TaskScheduler(Configurable):
         raw_msg,follow = self.pending[engine].pop(msg_id)
         if not self.maybe_run(msg_id, raw_msg, follow):
             # resubmit failed, put it back in our dependency tree
-            self.save_unmet(msg_id, raw_msg, Dependency(), follow)
+            self.save_unmet(msg_id, raw_msg, MET, follow)
         pass
     
+    
     @logged
-    def update_dependencies(self, dep_id):
+    def update_dependencies(self, dep_id, success=True):
         """dep_id just finished. Update our dependency
         table and submit any jobs that just became runable."""
-        
+        # print ("\n\n***********")
+        # pprint (dep_id)
+        # pprint (self.dependencies)
+        # pprint (self.depending)
+        # pprint (self.all_completed)
+        # pprint (self.all_failed)
+        # print ("\n\n***********\n\n")
         if dep_id not in self.dependencies:
             return
         jobs = self.dependencies.pop(dep_id)
-        for job in jobs:
-            msg_id, raw_msg, after, follow = self.depending[job]
-            if dep_id in after:
-                after.remove(dep_id)
-            if not after: # time deps met, maybe run
+        
+        for msg_id in jobs:
+            raw_msg, after, follow = self.depending[msg_id]
+            # if dep_id in after:
+            #     if after.mode == 'all' and (success or not after.success_only):
+            #         after.remove(dep_id)
+            
+            if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
+                self.fail_unreachable(msg_id)
+            
+            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
+                self.depending[msg_id][1] = MET
                 if self.maybe_run(msg_id, raw_msg, follow):
-                    self.depending.pop(job)
-                    for mid in follow:
+                    
+                    self.depending.pop(msg_id)
+                    for mid in follow.union(after):
                         if mid in self.dependencies:
                             self.dependencies[mid].remove(msg_id)
diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py
index 1516257..f30a4ca 100755
--- a/IPython/zmq/parallel/streamkernel.py
+++ b/IPython/zmq/parallel/streamkernel.py
@@ -34,7 +34,6 @@ from IPython.zmq.displayhook import DisplayHook
 from factory import SessionFactory
 from streamsession import StreamSession, Message, extract_header, serialize_object,\
                 unpack_apply_message, ISO8601, wrap_exception
-from dependency import UnmetDependency
 import heartmonitor
 from client import Client
 
@@ -266,9 +265,7 @@ class Kernel(SessionFactory):
             reply_content = exc_content
         else:
             reply_content = {'status' : 'ok'}
-            # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
-            # self.reply_socket.send(ident, zmq.SNDMORE)
-            # self.reply_socket.send_json(reply_msg)
+        
         reply_msg = self.session.send(stream, u'execute_reply', reply_content,
                     parent=parent, ident=ident, subheader = dict(started=started))
         logging.debug(str(reply_msg))
@@ -317,10 +314,7 @@ class Kernel(SessionFactory):
         suffix = prefix = "_" # prevent keyword collisions with lambda
         f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
         # if f.fun
-        if hasattr(f, 'func_name'):
-            fname = f.func_name
-        else:
-            fname = f.__name__
+        fname = getattr(f, '__name__', 'f')
         
         fname = prefix+fname.strip('<>')+suffix
         argname = prefix+"args"+suffix
@@ -350,16 +344,17 @@ class Kernel(SessionFactory):
             reply_content = exc_content
             result_buf = []
             
-            if exc_content['ename'] == UnmetDependency.__name__:
+            if exc_content['ename'] == 'UnmetDependency':
                 sub['dependencies_met'] = False
         else:
             reply_content = {'status' : 'ok'}
-            # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
-            # self.reply_socket.send(ident, zmq.SNDMORE)
-            # self.reply_socket.send_json(reply_msg)
+        
+        # put 'ok'/'error' status in header, for scheduler introspection:
+        sub['status'] = reply_content['status']
+        
         reply_msg = self.session.send(stream, u'apply_reply', reply_content,
                     parent=parent, ident=ident,buffers=result_buf, subheader=sub)
-        # print(Message(reply_msg), file=sys.__stdout__)
+        
         # if reply_msg['content']['status'] == u'error':
         #     self.abort_queues()
@@ -400,13 +395,11 @@ class Kernel(SessionFactory):
             return dispatcher
         
         for s in self.shell_streams:
-            # s.on_recv(printer)
             s.on_recv(make_dispatcher(s), copy=False)
-            # s.on_err(printer)
+            s.on_err(printer)
         
         if self.iopub_stream:
             self.iopub_stream.on_err(printer)
-            # self.iopub_stream.on_send(printer)
         
         #### while True mode:
         # while True:
diff --git a/IPython/zmq/parallel/streamsession.py b/IPython/zmq/parallel/streamsession.py
index 2df2a93..1780fef 100644
--- a/IPython/zmq/parallel/streamsession.py
+++ b/IPython/zmq/parallel/streamsession.py
@@ -399,12 +399,12 @@ class StreamSession(object):
                 stream.send(b, flag, copy=False)
             if buffers:
                 stream.send(buffers[-1], copy=False)
-        omsg = Message(msg)
+        # omsg = Message(msg)
         if self.debug:
-            pprint.pprint(omsg)
+            pprint.pprint(msg)
             pprint.pprint(to_send)
             pprint.pprint(buffers)
-        return omsg
+        return msg
     
     def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
         """Send a raw message via ident path.
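
Note: a minimal usage sketch of the new Dependency semantics introduced by this
patch (the msg_ids 'a', 'b', 'c' are illustrative placeholders; this assumes the
patched IPython.zmq.parallel.dependency module is importable):

    from IPython.zmq.parallel.dependency import Dependency

    completed = set(['a', 'b'])   # tasks that finished successfully
    failed = set(['c'])           # tasks that raised

    # default success_only=True: only successful results satisfy the dependency
    dep = Dependency(['a', 'c'], mode='all')
    print(dep.check(completed, failed))   # False: 'c' did not succeed
    print(dep.unreachable(failed))        # True: 'c' failed, so 'all' can never be met

    # success_only=False: failed tasks still count as finished
    dep = Dependency(['a', 'c'], mode='all', success_only=False)
    print(dep.check(completed, failed))   # True: 'a' succeeded and 'c' finished

    # 'any' mode only becomes unreachable once every dependency has failed
    dep = Dependency(['c'], mode='any')
    print(dep.unreachable(failed))        # True: the Scheduler would fail this task

    # the new flag survives the dict round-trip used on the wire
    dep2 = Dependency(dep.as_dict())
    assert dep2.success_only == dep.success_only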