diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py
index 2f49a90..400e16c 100644
--- a/IPython/parallel/controller/scheduler.py
+++ b/IPython/parallel/controller/scheduler.py
@@ -23,6 +23,7 @@ from __future__ import print_function
 
 import logging
 import sys
+import time
 
 from datetime import datetime, timedelta
 from random import randint, random
@@ -119,9 +120,33 @@ def leastload(loads):
 #---------------------------------------------------------------------
 # Classes
 #---------------------------------------------------------------------
+
+# store empty default dependency:
 MET = Dependency([])
+
+
+class Job(object):
+    """Simple container for a job"""
+    def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
+        self.msg_id = msg_id
+        self.raw_msg = raw_msg
+        self.idents = idents
+        self.msg = msg
+        self.header = header
+        self.targets = targets
+        self.after = after
+        self.follow = follow
+        self.timeout = timeout
+
+        self.timestamp = time.time()
+        self.blacklist = set()
+
+    @property
+    def dependents(self):
+        return self.follow.union(self.after)
+
 
 class TaskScheduler(SessionFactory):
     """Python TaskScheduler object.
@@ -168,7 +193,7 @@ class TaskScheduler(SessionFactory):
     graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
     retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
     # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
-    depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
+    depending = Dict() # dict by msg_id of Jobs
     pending = Dict() # dict by engine_uuid of submitted tasks
     completed = Dict() # dict by engine_uuid of completed tasks
     failed = Dict() # dict by engine_uuid of failed tasks
@@ -181,7 +206,7 @@ class TaskScheduler(SessionFactory):
     all_failed = Set() # set of all failed tasks
     all_done = Set() # set of all finished tasks=union(completed,failed)
     all_ids = Set() # set of all submitted task IDs
-    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
+
     auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
 
     ident = CBytes() # ZMQ identity. This should just be self.session.session
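For reference, a standalone sketch of the new Job container introduced above. It mirrors the class from the hunk, but the `after`/`follow` Dependency objects are stood in by plain sets so it runs on its own, and the driver values at the bottom are made up purely to exercise the container:

```python
import time

class Job(object):
    """Simple container for a job (mirrors the class added above)."""
    def __init__(self, msg_id, raw_msg, idents, msg, header,
                 targets, after, follow, timeout):
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout
        # recorded at submission; update_graph sorts on this later
        self.timestamp = time.time()
        # engines that raised UnmetDependency for this job; replaces the
        # scheduler-wide TaskScheduler.blacklist dict
        self.blacklist = set()

    @property
    def dependents(self):
        # every msg_id this job waits on, in either dependency set
        return self.follow.union(self.after)

# made-up values, just for illustration:
job = Job(msg_id='a1b2', raw_msg=[b'...'], idents=[b'client'], msg=None,
          header={}, targets=set(), after={'t0'}, follow={'t1'}, timeout=None)
print(job.dependents == {'t0', 't1'})  # True
```

Keeping `idents`, `msg`, and `header` on the job also removes the old FIXME in `fail_unreachable`, which had to re-unpack an already-unpacked message.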
@@ -380,7 +405,10 @@
             # which timedelta does not accept
             timeout = datetime.now() + timedelta(0,float(timeout),0)
 
-        args = [raw_msg, targets, after, follow, timeout]
+        job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
+                header=header, targets=targets, after=after, follow=follow,
+                timeout=timeout,
+        )
 
         # validate and reduce dependencies:
         for dep in after,follow:
@@ -388,22 +416,22 @@
                 continue
             # check valid:
             if msg_id in dep or dep.difference(self.all_ids):
-                self.depending[msg_id] = args
+                self.depending[msg_id] = job
                 return self.fail_unreachable(msg_id, error.InvalidDependency)
 
             # check if unreachable:
             if dep.unreachable(self.all_completed, self.all_failed):
-                self.depending[msg_id] = args
+                self.depending[msg_id] = job
                 return self.fail_unreachable(msg_id)
 
         if after.check(self.all_completed, self.all_failed):
             # time deps already met, try to run
-            if not self.maybe_run(msg_id, *args):
+            if not self.maybe_run(job):
                 # can't run yet
                 if msg_id not in self.all_failed:
                     # could have failed as unreachable
-                    self.save_unmet(msg_id, *args)
+                    self.save_unmet(job)
         else:
-            self.save_unmet(msg_id, *args)
+            self.save_unmet(job)
 
     def audit_timeouts(self):
         """Audit all waiting tasks for expired timeouts."""
@@ -411,8 +439,8 @@
         for msg_id in self.depending.keys():
             # must recheck, in case one failure cascaded to another:
             if msg_id in self.depending:
-                raw,after,targets,follow,timeout = self.depending[msg_id]
-                if timeout and timeout < now:
+                job = self.depending[msg_id]
+                if job.timeout and job.timeout < now:
                     self.fail_unreachable(msg_id, error.TaskTimeout)
 
     def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
@@ -421,15 +449,11 @@
         if msg_id not in self.depending:
             self.log.error("msg %r already failed!", msg_id)
             return
-        raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
-        for mid in follow.union(after):
+        job = self.depending.pop(msg_id)
+        for mid in job.dependents:
             if mid in self.graph:
                 self.graph[mid].remove(msg_id)
 
-        # FIXME: unpacking a message I've already unpacked, but didn't save:
-        idents,msg = self.session.feed_identities(raw_msg, copy=False)
-        header = self.session.unpack(msg[1].bytes)
-
         try:
             raise why()
         except:
@@ -439,20 +463,20 @@
         self.all_failed.add(msg_id)
 
         msg = self.session.send(self.client_stream, 'apply_reply', content,
-                                parent=header, ident=idents)
-        self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
+                                parent=job.header, ident=job.idents)
+        self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
 
         self.update_graph(msg_id, success=False)
 
-    def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
+    def maybe_run(self, job):
        """check location dependencies, and run if they are met."""
+        msg_id = job.msg_id
        self.log.debug("Attempting to assign task %s", msg_id)
        if not self.targets:
            # no engines, definitely can't run
            return False
 
-        blacklist = self.blacklist.setdefault(msg_id, set())
-        if follow or targets or blacklist or self.hwm:
+        if job.follow or job.targets or job.blacklist or self.hwm:
            # we need a can_run filter
            def can_run(idx):
                # check hwm
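The dispatch path above leans on the Dependency API: `after.check(...)` gates when a job may run, and `dep.unreachable(...)` decides when to fail it. As a rough, runnable stand-in for the default case (all dependencies must succeed), not the real `IPython.parallel.controller.dependency.Dependency`:

```python
# NOTE: an approximation of Dependency semantics for illustration only.
class SimpleDep(set):
    """Stand-in for Dependency with all=True, success=True, failure=False."""
    def check(self, completed, failed=None):
        # met once every required msg_id has finished successfully
        return self.issubset(completed)

    def unreachable(self, completed, failed):
        # can never be met once any required msg_id has failed instead
        return bool(self.intersection(failed))

after = SimpleDep(['t1', 't2'])
print(after.check({'t1'}, set()))         # False -> save_unmet
print(after.check({'t1', 't2'}, set()))   # True  -> maybe_run
print(after.unreachable({'t2'}, {'t1'}))  # True  -> fail_unreachable
```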
@@ -460,56 +484,57 @@
                     return False
                 target = self.targets[idx]
                 # check blacklist
-                if target in blacklist:
+                if target in job.blacklist:
                     return False
                 # check targets
-                if targets and target not in targets:
+                if job.targets and target not in job.targets:
                     return False
                 # check follow
-                return follow.check(self.completed[target], self.failed[target])
+                return job.follow.check(self.completed[target], self.failed[target])
 
             indices = filter(can_run, range(len(self.targets)))
 
             if not indices:
                 # couldn't run
-                if follow.all:
+                if job.follow.all:
                     # check follow for impossibility
                     dests = set()
                     relevant = set()
-                    if follow.success:
+                    if job.follow.success:
                         relevant = self.all_completed
-                    if follow.failure:
+                    if job.follow.failure:
                         relevant = relevant.union(self.all_failed)
-                    for m in follow.intersection(relevant):
+                    for m in job.follow.intersection(relevant):
                         dests.add(self.destinations[m])
                     if len(dests) > 1:
-                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
+                        self.depending[msg_id] = job
                         self.fail_unreachable(msg_id)
                         return False
-                if targets:
+                if job.targets:
                     # check blacklist+targets for impossibility
-                    targets.difference_update(blacklist)
-                    if not targets or not targets.intersection(self.targets):
-                        self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
+                    job.targets.difference_update(job.blacklist)
+                    if not job.targets or not job.targets.intersection(self.targets):
+                        self.depending[msg_id] = job
                         self.fail_unreachable(msg_id)
                         return False
                 return False
         else:
             indices = None
 
-        self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
+        self.submit_task(job, indices)
         return True
 
-    def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
+    def save_unmet(self, job):
         """Save a message for later submission when its dependencies are met."""
-        self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
+        msg_id = job.msg_id
+        self.depending[msg_id] = job
         # track the ids in follow or after, but not those already finished
-        for dep_id in after.union(follow).difference(self.all_done):
+        for dep_id in job.after.union(job.follow).difference(self.all_done):
             if dep_id not in self.graph:
                 self.graph[dep_id] = set()
             self.graph[dep_id].add(msg_id)
 
-    def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
+    def submit_task(self, job, indices=None):
         """Submit a task to any of a subset of our targets."""
         if indices:
             loads = [self.loads[i] for i in indices]
@@ -522,12 +547,12 @@
         # print (target, map(str, msg[:3]))
         # send job to the engine
         self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
-        self.engine_stream.send_multipart(raw_msg, copy=False)
+        self.engine_stream.send_multipart(job.raw_msg, copy=False)
         # update load
         self.add_job(idx)
-        self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
+        self.pending[target][job.msg_id] = job
         # notify Hub
-        content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
+        content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
         self.session.send(self.mon_stream, 'task_destination', content=content,
                         ident=[b'tracktask',self.ident])
 
@@ -581,7 +606,6 @@
         self.client_stream.send_multipart(raw_msg, copy=False)
         # now, update our data structures
         msg_id = parent['msg_id']
-        self.blacklist.pop(msg_id, None)
         self.pending[engine].pop(msg_id)
         if success:
             self.completed[engine].add(msg_id)
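Note the one-line fix folded in above: the new `maybe_run` must consult `job.blacklist`, since the local `blacklist` variable it previously read was removed. Separately, `submit_task` narrows `self.loads` to the indices that passed `can_run` and hands them to the scheduling scheme (`leastload` by default, visible as hunk context near the top of this diff). A toy run with faked loads and indices:

```python
def leastload(loads):
    """always choose the lowest load -- the default scheme"""
    return loads.index(min(loads))

# faked scheduler state, for illustration:
loads = [2, 0, 1]   # outstanding tasks per engine, indexed like self.targets
indices = [0, 2]    # engines that passed the can_run filter (engine 1 excluded)

eligible = [loads[i] for i in indices]   # [2, 1], as in submit_task
idx = indices[leastload(eligible)]       # map back to a real engine index
print(idx)  # 2: the least-loaded engine that is actually allowed this job
```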
@@ -599,21 +623,17 @@
         engine = idents[0]
         msg_id = parent['msg_id']
 
-        if msg_id not in self.blacklist:
-            self.blacklist[msg_id] = set()
-        self.blacklist[msg_id].add(engine)
-
-        args = self.pending[engine].pop(msg_id)
-        raw,targets,after,follow,timeout = args
+        job = self.pending[engine].pop(msg_id)
+        job.blacklist.add(engine)
 
-        if self.blacklist[msg_id] == targets:
-            self.depending[msg_id] = args
+        if job.blacklist == job.targets:
+            self.depending[msg_id] = job
             self.fail_unreachable(msg_id)
-        elif not self.maybe_run(msg_id, *args):
+        elif not self.maybe_run(job):
             # resubmit failed
             if msg_id not in self.all_failed:
                 # put it back in our dependency tree
-                self.save_unmet(msg_id, *args)
+                self.save_unmet(job)
 
         if self.hwm:
             try:
@@ -646,21 +666,22 @@
         # recheck *all* jobs if
         # a) we have HWM and an engine just become no longer full
         # or b) dep_id was given as None
+
         if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
             jobs = self.depending.keys()
 
-        for msg_id in jobs:
-            raw_msg, targets, after, follow, timeout = self.depending[msg_id]
-
-            if after.unreachable(self.all_completed, self.all_failed)\
-                or follow.unreachable(self.all_completed, self.all_failed):
+        for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
+            job = self.depending[msg_id]
+
+            if job.after.unreachable(self.all_completed, self.all_failed)\
+                or job.follow.unreachable(self.all_completed, self.all_failed):
                 self.fail_unreachable(msg_id)
 
-            elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
-                if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
+            elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
+                if self.maybe_run(job):
 
                     self.depending.pop(msg_id)
-                    for mid in follow.union(after):
+                    for mid in job.dependents:
                         if mid in self.graph:
                             self.graph[mid].remove(msg_id)
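Beyond the container refactor, the one behavioural change in `update_graph` is ordering: rechecked jobs are now visited oldest-first via `Job.timestamp` instead of arbitrary dict order, so long-waiting tasks are resubmitted before newer ones. A toy illustration (`StubJob` is a made-up stand-in carrying only the attribute the sort uses):

```python
import time

class StubJob(object):
    """made-up stand-in with only the attribute the sort key reads"""
    def __init__(self):
        self.timestamp = time.time()

depending = {}
for msg_id in ('c', 'a', 'b'):     # submission order != alphabetical order
    depending[msg_id] = StubJob()
    time.sleep(0.01)               # force distinct timestamps

jobs = depending.keys()
ordered = sorted(jobs, key=lambda msg_id: depending[msg_id].timestamp)
print(ordered)  # ['c', 'a', 'b']: oldest submission first
```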