diff --git a/IPython/config/default/ipcontroller_config.py b/IPython/config/default/ipcontroller_config.py
index 492044c..347c922 100644
--- a/IPython/config/default/ipcontroller_config.py
+++ b/IPython/config/default/ipcontroller_config.py
@@ -71,11 +71,11 @@ c = get_config()
 # dying engines, dependencies, or engine-subset load-balancing.
 # c.ControllerFactory.scheme = 'pure'
 
-# The pure ZMQ scheduler can limit the number of outstanding tasks per engine
-# by using the ZMQ HWM option. This allows engines with long-running tasks
+# The Python scheduler can limit the number of outstanding tasks per engine
+# by using an HWM option. This allows engines with long-running tasks
 # to not steal too many tasks from other engines. The default is 0, which
 # means agressively distribute messages, never waiting for them to finish.
-# c.ControllerFactory.hwm = 1
+# c.TaskScheduler.hwm = 0
 
 # Whether to use Threads or Processes to start the Schedulers. Threads will
 # use less resources, but potentially reduce throughput. Default is to
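The hunk above renames the commented-out option from c.ControllerFactory.hwm to
c.TaskScheduler.hwm, since the limit now lives in the Python TaskScheduler rather
than the pure ZMQ scheduler. A minimal usage sketch for a controller profile
(the value 1 is an assumption for illustration; 0, the default, disables the limit):

    # ipcontroller_config.py -- sketch; hwm=1 is an example value
    c = get_config()

    # Send each engine at most one task at a time: an engine already
    # running a task is skipped until it reports a result, so queued
    # work cannot pile up behind a long-running task.
    c.TaskScheduler.hwm = 1
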
met.""" blacklist = self.blacklist.setdefault(msg_id, set()) - if follow or targets or blacklist: + if follow or targets or blacklist or self.hwm: # we need a can_run filter def can_run(idx): - target = self.targets[idx] - # check targets - if targets and target not in targets: + # check hwm + if self.loads[idx] == self.hwm: return False + target = self.targets[idx] # check blacklist if target in blacklist: return False + # check targets + if targets and target not in targets: + return False # check follow return follow.check(self.completed[target], self.failed[target]) @@ -426,13 +432,17 @@ class TaskScheduler(SessionFactory): idx = indices[idx] target = self.targets[idx] # print (target, map(str, msg[:3])) + # send job to the engine self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) self.engine_stream.send_multipart(raw_msg, copy=False) + # update load self.add_job(idx) self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout) + # notify Hub content = dict(msg_id=msg_id, engine_id=target) self.session.send(self.mon_stream, 'task_destination', content=content, ident=['tracktask',self.session.session]) + #----------------------------------------------------------------------- # Result Handling @@ -443,10 +453,13 @@ class TaskScheduler(SessionFactory): try: idents,msg = self.session.feed_identities(raw_msg, copy=False) msg = self.session.unpack_message(msg, content=False, copy=False) - except: + engine = idents[0] + idx = self.targets.index(engine) + self.finish_job(idx) + except Exception: self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True) return - + header = msg['header'] if header.get('dependencies_met', True): success = (header['status'] == 'ok') @@ -496,17 +509,26 @@ class TaskScheduler(SessionFactory): if self.blacklist[msg_id] == targets: self.depending[msg_id] = args - return self.fail_unreachable(msg_id) - + self.fail_unreachable(msg_id) elif not self.maybe_run(msg_id, *args): # resubmit failed, put it back in our dependency tree self.save_unmet(msg_id, *args) + if self.hwm: + idx = self.targets.index(engine) + if self.loads[idx] == self.hwm-1: + self.update_graph(None) + + @logged - def update_graph(self, dep_id, success=True): + def update_graph(self, dep_id=None, success=True): """dep_id just finished. Update our dependency - graph and submit any jobs that just became runable.""" + graph and submit any jobs that just became runable. + + Called with dep_id=None to update graph for hwm, but without finishing + a task. + """ # print ("\n\n***********") # pprint (dep_id) # pprint (self.graph) @@ -514,9 +536,12 @@ class TaskScheduler(SessionFactory): # pprint (self.all_completed) # pprint (self.all_failed) # print ("\n\n***********\n\n") - if dep_id not in self.graph: - return - jobs = self.graph.pop(dep_id) + # update any jobs that depended on the dependency + jobs = self.graph.pop(dep_id, []) + # if we have HWM and an engine just become no longer full + # recheck *all* jobs: + if self.hwm and any( [ load==self.hwm-1 for load in self.loads]): + jobs = self.depending.keys() for msg_id in jobs: raw_msg, targets, after, follow, timeout = self.depending[msg_id]