From a514d133c6a8bb25e471dce528dd4dcfb7a87d3d 2011-04-08 00:38:19 From: MinRK Date: 2011-04-08 00:38:19 Subject: [PATCH] allow load-balancing across subsets of engines --- diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index fb082e8..65c8e62 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -326,7 +326,7 @@ class Client(object): else: self._registration_socket.connect(url) self._engines = ReverseDict() - self._ids = set() + self._ids = [] self.outstanding=set() self.results = {} self.metadata = {} @@ -370,7 +370,8 @@ class Client(object): for k,v in engines.iteritems(): eid = int(k) self._engines[eid] = bytes(v) # force not unicode - self._ids.add(eid) + self._ids.append(eid) + self._ids = sorted(self._ids) if sorted(self._engines.keys()) != range(len(self._engines)) and \ self._task_scheme == 'pure' and self._task_socket: self._stop_scheduling_tasks() @@ -470,7 +471,6 @@ class Client(object): eid = content['id'] d = {eid : content['queue']} self._update_engines(d) - self._ids.add(int(eid)) def _unregister_engine(self, msg): """Unregister an engine that has died.""" @@ -664,9 +664,9 @@ class Client(object): """property for convenient RemoteFunction generation. >>> @client.remote - ... def f(): + ... def getpid(): import os - print (os.getpid()) + return os.getpid() """ return remote(self, block=self.block) @@ -867,6 +867,7 @@ class Client(object): # pass to Dependency constructor return list(Dependency(dep)) + @defaultblock def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None, after=None, follow=None, timeout=None): """Call `f(*args, **kwargs)` on a remote engine(s), returning the result. @@ -903,24 +904,9 @@ class Client(object): Run on each specified engine if int: Run on single engine - - after : Dependency or collection of msg_ids - Only for load-balanced execution (targets=None) - Specify a list of msg_ids as a time-based dependency. - This job will only be run *after* the dependencies - have been met. - - follow : Dependency or collection of msg_ids - Only for load-balanced execution (targets=None) - Specify a list of msg_ids as a location-based dependency. - This job will only be run on an engine where this dependency - is met. - timeout : float/int or None - Only for load-balanced execution (targets=None) - Specify an amount of time (in seconds) for the scheduler to - wait for dependencies to be met before failing with a - DependencyTimeout. + after,follow,timeout only used in `apply_balanced`. See that docstring + for details. Returns ------- @@ -947,25 +933,88 @@ class Client(object): if not isinstance(kwargs, dict): raise TypeError("kwargs must be dict, not %s"%type(kwargs)) - options = dict(bound=bound, block=block) + options = dict(bound=bound, block=block, targets=targets) if targets is None: - if self._task_socket: - return self._apply_balanced(f, args, kwargs, timeout=timeout, - after=after, follow=follow, **options) - else: - msg = "Task farming is disabled" - if self._task_scheme == 'pure': - msg += " because the pure ZMQ scheduler cannot handle" - msg += " disappearing engines." - raise RuntimeError(msg) + return self.apply_balanced(f, args, kwargs, timeout=timeout, + after=after, follow=follow, **options) else: - return self._apply_direct(f, args, kwargs, targets=targets, **options) + if follow or after or timeout: + msg = "follow, after, and timeout args are only used for load-balanced" + msg += "execution." + raise ValueError(msg) + return self._apply_direct(f, args, kwargs, **options) - def _apply_balanced(self, f, args, kwargs, bound=True, block=None, + @defaultblock + def apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None, after=None, follow=None, timeout=None): - """The underlying method for applying functions in a load balanced - manner, via the task queue.""" + """call f(*args, **kwargs) remotely in a load-balanced manner. + + Parameters + ---------- + + f : function + The fuction to be called remotely + args : tuple/list + The positional arguments passed to `f` + kwargs : dict + The keyword arguments passed to `f` + bound : bool (default: True) + Whether to execute in the Engine(s) namespace, or in a clean + namespace not affecting the engine. + block : bool (default: self.block) + Whether to wait for the result, or return immediately. + False: + returns AsyncResult + True: + returns actual result(s) of f(*args, **kwargs) + if multiple targets: + list of results, matching `targets` + targets : int,list of ints, 'all', None + Specify the destination of the job. + if None: + Submit via Task queue for load-balancing. + if 'all': + Run on all active engines + if list: + Run on each specified engine + if int: + Run on single engine + + after : Dependency or collection of msg_ids + Only for load-balanced execution (targets=None) + Specify a list of msg_ids as a time-based dependency. + This job will only be run *after* the dependencies + have been met. + + follow : Dependency or collection of msg_ids + Only for load-balanced execution (targets=None) + Specify a list of msg_ids as a location-based dependency. + This job will only be run on an engine where this dependency + is met. + + timeout : float/int or None + Only for load-balanced execution (targets=None) + Specify an amount of time (in seconds) for the scheduler to + wait for dependencies to be met before failing with a + DependencyTimeout. + + Returns + ------- + if block is False: + return AsyncResult wrapping msg_id + output of AsyncResult.get() is identical to that of `apply(...block=True)` + else: + wait for, and return actual result of `f(*args, **kwargs)` + + """ + + if self._task_socket is None: + msg = "Task farming is disabled" + if self._task_scheme == 'pure': + msg += " because the pure ZMQ scheduler cannot handle" + msg += " disappearing engines." + raise RuntimeError(msg) if self._task_scheme == 'pure': # pure zmq scheme doesn't support dependencies @@ -978,9 +1027,26 @@ class Client(object): warnings.warn(msg, RuntimeWarning) + # defaults: + args = args if args is not None else [] + kwargs = kwargs if kwargs is not None else {} + + if targets: + idents,_ = self._build_targets(targets) + else: + idents = [] + + # enforce types of f,args,kwrags + if not callable(f): + raise TypeError("f must be callable, not %s"%type(f)) + if not isinstance(args, (tuple, list)): + raise TypeError("args must be tuple or list, not %s"%type(args)) + if not isinstance(kwargs, dict): + raise TypeError("kwargs must be dict, not %s"%type(kwargs)) + after = self._build_dependency(after) follow = self._build_dependency(follow) - subheader = dict(after=after, follow=follow, timeout=timeout) + subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents) bufs = ss.pack_apply_message(f,args,kwargs) content = dict(bound=bound) @@ -991,31 +1057,40 @@ class Client(object): self.history.append(msg_id) ar = AsyncResult(self, [msg_id], fname=f.__name__) if block: - return ar.get() + try: + return ar.get() + except KeyboardInterrupt: + return ar else: return ar def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None): """Then underlying method for applying functions to specific engines - via the MUX queue.""" + via the MUX queue. + + Not to be called directly! + """ - queues,targets = self._build_targets(targets) + idents,targets = self._build_targets(targets) subheader = {} content = dict(bound=bound) bufs = ss.pack_apply_message(f,args,kwargs) msg_ids = [] - for queue in queues: + for ident in idents: msg = self.session.send(self._mux_socket, "apply_request", - content=content, buffers=bufs,ident=queue, subheader=subheader) + content=content, buffers=bufs, ident=ident, subheader=subheader) msg_id = msg['msg_id'] self.outstanding.add(msg_id) self.history.append(msg_id) msg_ids.append(msg_id) ar = AsyncResult(self, msg_ids, fname=f.__name__) if block: - return ar.get() + try: + return ar.get() + except KeyboardInterrupt: + return ar else: return ar @@ -1037,6 +1112,15 @@ class Client(object): """Decorator for making a RemoteFunction.""" return remote(self, bound=bound, targets=targets, block=block) + def view(self, targets=None, balanced=False): + """Method for constructing View objects""" + if not balanced: + if not targets: + targets = slice(None) + return self[targets] + else: + return LoadBalancedView(self, targets) + #-------------------------------------------------------------------------- # Data movement #-------------------------------------------------------------------------- diff --git a/IPython/zmq/parallel/scheduler.py b/IPython/zmq/parallel/scheduler.py index 4e0430a..947f46f 100644 --- a/IPython/zmq/parallel/scheduler.py +++ b/IPython/zmq/parallel/scheduler.py @@ -265,6 +265,9 @@ class TaskScheduler(SessionFactory): msg_id = header['msg_id'] self.all_ids.add(msg_id) + # targets + targets = set(header.get('targets', [])) + # time dependencies after = Dependency(header.get('after', [])) if after.all: @@ -279,28 +282,31 @@ class TaskScheduler(SessionFactory): # location dependencies follow = Dependency(header.get('follow', [])) + # turn timeouts into datetime objects: + timeout = header.get('timeout', None) + if timeout: + timeout = datetime.now() + timedelta(0,timeout,0) + + args = [raw_msg, targets, after, follow, timeout] + + # validate and reduce dependencies: for dep in after,follow: # check valid: if msg_id in dep or dep.difference(self.all_ids): - self.depending[msg_id] = [raw_msg,MET,MET,None] + self.depending[msg_id] = args return self.fail_unreachable(msg_id, error.InvalidDependency) # check if unreachable: if dep.unreachable(self.all_failed): - self.depending[msg_id] = [raw_msg,MET,MET,None] + self.depending[msg_id] = args return self.fail_unreachable(msg_id) - # turn timeouts into datetime objects: - timeout = header.get('timeout', None) - if timeout: - timeout = datetime.now() + timedelta(0,timeout,0) - if after.check(self.all_completed, self.all_failed): # time deps already met, try to run - if not self.maybe_run(msg_id, raw_msg, follow, timeout): + if not self.maybe_run(msg_id, *args): # can't run yet - self.save_unmet(msg_id, raw_msg, after, follow, timeout) + self.save_unmet(msg_id, *args) else: - self.save_unmet(msg_id, raw_msg, after, follow, timeout) + self.save_unmet(msg_id, *args) # @logged def audit_timeouts(self): @@ -309,17 +315,18 @@ class TaskScheduler(SessionFactory): for msg_id in self.depending.keys(): # must recheck, in case one failure cascaded to another: if msg_id in self.depending: - raw,after,follow,timeout = self.depending[msg_id] + raw,after,targets,follow,timeout = self.depending[msg_id] if timeout and timeout < now: self.fail_unreachable(msg_id, timeout=True) @logged def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): - """a message has become unreachable""" + """a task has become unreachable, send a reply with an ImpossibleDependency + error.""" if msg_id not in self.depending: self.log.error("msg %r already failed!"%msg_id) return - raw_msg, after, follow, timeout = self.depending.pop(msg_id) + raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id) for mid in follow.union(after): if mid in self.graph: self.graph[mid].remove(msg_id) @@ -344,37 +351,51 @@ class TaskScheduler(SessionFactory): self.update_graph(msg_id, success=False) @logged - def maybe_run(self, msg_id, raw_msg, follow=None, timeout=None): + def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout): """check location dependencies, and run if they are met.""" - - if follow: + blacklist = self.blacklist.setdefault(msg_id, set()) + if follow or targets or blacklist: + # we need a can_run filter def can_run(idx): target = self.targets[idx] - return target not in self.blacklist.get(msg_id, []) and\ - follow.check(self.completed[target], self.failed[target]) + # check targets + if targets and target not in targets: + return False + # check blacklist + if target in blacklist: + return False + # check follow + return follow.check(self.completed[target], self.failed[target]) indices = filter(can_run, range(len(self.targets))) if not indices: + # couldn't run if follow.all: + # check follow for impossibility dests = set() relevant = self.all_completed if follow.success_only else self.all_done for m in follow.intersection(relevant): dests.add(self.destinations[m]) if len(dests) > 1: self.fail_unreachable(msg_id) - - + return False + if targets: + # check blacklist+targets for impossibility + targets.difference_update(blacklist) + if not targets or not targets.intersection(self.targets): + self.fail_unreachable(msg_id) + return False return False else: indices = None - self.submit_task(msg_id, raw_msg, follow, timeout, indices) + self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices) return True @logged - def save_unmet(self, msg_id, raw_msg, after, follow, timeout): + def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout): """Save a message for later submission when its dependencies are met.""" - self.depending[msg_id] = [raw_msg,after,follow,timeout] + self.depending[msg_id] = [raw_msg,targets,after,follow,timeout] # track the ids in follow or after, but not those already finished for dep_id in after.union(follow).difference(self.all_done): if dep_id not in self.graph: @@ -382,7 +403,7 @@ class TaskScheduler(SessionFactory): self.graph[dep_id].add(msg_id) @logged - def submit_task(self, msg_id, raw_msg, follow, timeout, indices=None): + def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None): """Submit a task to any of a subset of our targets.""" if indices: loads = [self.loads[i] for i in indices] @@ -396,7 +417,7 @@ class TaskScheduler(SessionFactory): self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) self.engine_stream.send_multipart(raw_msg, copy=False) self.add_job(idx) - self.pending[target][msg_id] = (raw_msg, follow, timeout) + self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout) content = dict(msg_id=msg_id, engine_id=target) self.session.send(self.mon_stream, 'task_destination', content=content, ident=['tracktask',self.session.session]) @@ -406,6 +427,7 @@ class TaskScheduler(SessionFactory): #----------------------------------------------------------------------- @logged def dispatch_result(self, raw_msg): + """dispatch method for result replies""" try: idents,msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, content=False, copy=False) @@ -424,6 +446,7 @@ class TaskScheduler(SessionFactory): @logged def handle_result(self, idents, parent, raw_msg, success=True): + """handle a real task result, either success or failure""" # first, relay result to client engine = idents[0] client = idents[1] @@ -448,21 +471,30 @@ class TaskScheduler(SessionFactory): @logged def handle_unmet_dependency(self, idents, parent): + """handle an unmet dependency""" engine = idents[0] msg_id = parent['msg_id'] + if msg_id not in self.blacklist: self.blacklist[msg_id] = set() self.blacklist[msg_id].add(engine) - raw_msg,follow,timeout = self.pending[engine].pop(msg_id) - if not self.maybe_run(msg_id, raw_msg, follow, timeout): + + args = self.pending[engine].pop(msg_id) + raw,targets,after,follow,timeout = args + + if self.blacklist[msg_id] == targets: + self.depending[msg_id] = args + return self.fail_unreachable(msg_id) + + elif not self.maybe_run(msg_id, *args): # resubmit failed, put it back in our dependency tree - self.save_unmet(msg_id, raw_msg, MET, follow, timeout) - pass + self.save_unmet(msg_id, *args) + @logged def update_graph(self, dep_id, success=True): """dep_id just finished. Update our dependency - table and submit any jobs that just became runable.""" + graph and submit any jobs that just became runable.""" # print ("\n\n***********") # pprint (dep_id) # pprint (self.graph) @@ -475,7 +507,7 @@ class TaskScheduler(SessionFactory): jobs = self.graph.pop(dep_id) for msg_id in jobs: - raw_msg, after, follow, timeout = self.depending[msg_id] + raw_msg, targets, after, follow, timeout = self.depending[msg_id] # if dep_id in after: # if after.all and (success or not after.success_only): # after.remove(dep_id) @@ -484,8 +516,7 @@ class TaskScheduler(SessionFactory): self.fail_unreachable(msg_id) elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run - self.depending[msg_id][1] = MET - if self.maybe_run(msg_id, raw_msg, follow, timeout): + if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout): self.depending.pop(msg_id) for mid in follow.union(after): diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index a67c4da..a397728 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -66,10 +66,15 @@ class View(object): Don't use this class, use subclasses. """ - _targets = None block=None bound=None history=None + outstanding = set() + results = {} + + _targets = None + _apply_name = 'apply' + _default_names = ['targets', 'block'] def __init__(self, client, targets=None): self.client = client @@ -80,6 +85,9 @@ class View(object): self.history = [] self.outstanding = set() self.results = {} + for name in self._default_names: + setattr(self, name, getattr(self, name, None)) + def __repr__(self): strtargets = str(self._targets) @@ -95,11 +103,23 @@ class View(object): def targets(self, value): self._targets = value # raise AttributeError("Cannot set my targets argument after construction!") - + + def _defaults(self, *excludes): + """return dict of our default attributes, excluding names given.""" + d = {} + for name in self._default_names: + if name not in excludes: + d[name] = getattr(self, name) + return d + @sync_results def spin(self): """spin the client, and sync""" self.client.spin() + + @property + def _apply(self): + return getattr(self.client, self._apply_name) @sync_results @save_ids @@ -113,7 +133,7 @@ class View(object): else: returns actual result of f(*args, **kwargs) """ - return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound) + return self._apply(f, args, kwargs, **self._defaults()) @save_ids def apply_async(self, f, *args, **kwargs): @@ -123,7 +143,8 @@ class View(object): returns msg_id """ - return self.client.apply(f,args,kwargs, block=False, targets=self.targets, bound=False) + d = self._defaults('block', 'bound') + return self._apply(f,args,kwargs, block=False, bound=False, **d) @spin_after @save_ids @@ -135,7 +156,8 @@ class View(object): returns: actual result of f(*args, **kwargs) """ - return self.client.apply(f,args,kwargs, block=True, targets=self.targets, bound=False) + d = self._defaults('block', 'bound') + return self._apply(f,args,kwargs, block=True, bound=False, **d) @sync_results @save_ids @@ -150,7 +172,8 @@ class View(object): This method has access to the targets' globals """ - return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True) + d = self._defaults('bound') + return self._apply(f, args, kwargs, bound=True, **d) @sync_results @save_ids @@ -163,7 +186,8 @@ class View(object): This method has access to the targets' globals """ - return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True) + d = self._defaults('block', 'bound') + return self._apply(f, args, kwargs, block=False, bound=True, **d) @spin_after @save_ids @@ -175,7 +199,8 @@ class View(object): This method has access to the targets' globals """ - return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True) + d = self._defaults('block', 'bound') + return self._apply(f, args, kwargs, block=True, bound=True, **d) @spin_after @save_ids @@ -337,24 +362,22 @@ class LoadBalancedView(View): Typically created via: - >>> lbv = client[None] - + >>> v = client[None] + but can also be created with: - >>> lbc = LoadBalancedView(client) + >>> v = client.view([1,3],balanced=True) + + which would restrict loadbalancing to between engines 1 and 3. - TODO: allow subset of engines across which to balance. """ - def __repr__(self): - return "<%s %s>"%(self.__class__.__name__, self.client._config['url']) - @property - def targets(self): - return None - - @targets.setter - def targets(self, value): - raise AttributeError("Cannot set targets for LoadbalancedView!") - - \ No newline at end of file + _apply_name = 'apply_balanced' + _default_names = ['targets', 'block', 'bound', 'follow', 'after', 'timeout'] + + def __init__(self, client, targets=None): + super(LoadBalancedView, self).__init__(client, targets) + self._ntargets = 1 + self._apply_name = 'apply_balanced' +