diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py index cc589fc..68ee8db 100644 --- a/IPython/parallel/controller/scheduler.py +++ b/IPython/parallel/controller/scheduler.py @@ -140,9 +140,10 @@ class Job(object): self.after = after self.follow = follow self.timeout = timeout - self.removed = False # used for lazy-delete from sorted queue + self.removed = False # used for lazy-delete from sorted queue self.timestamp = time.time() + self.timeout_id = 0 self.blacklist = set() def __lt__(self, other): @@ -155,6 +156,7 @@ class Job(object): def dependents(self): return self.follow.union(self.after) + class TaskScheduler(SessionFactory): """Python TaskScheduler object. @@ -433,19 +435,14 @@ class TaskScheduler(SessionFactory): # location dependencies follow = Dependency(md.get('follow', [])) - # turn timeouts into datetime objects: timeout = md.get('timeout', None) if timeout: - timeout = time.time() + float(timeout) + timeout = float(timeout) job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, header=header, targets=targets, after=after, follow=follow, timeout=timeout, metadata=md, ) - if timeout: - # schedule timeout callback - self.loop.add_timeout(timeout, lambda : self.job_timeout(job)) - # validate and reduce dependencies: for dep in after,follow: if not dep: # empty dependency @@ -469,11 +466,14 @@ class TaskScheduler(SessionFactory): else: self.save_unmet(job) - def job_timeout(self, job): + def job_timeout(self, job, timeout_id): """callback for a job's timeout. The job may or may not have been run at this point. """ + if job.timeout_id != timeout_id: + # not the most recent call + return now = time.time() if job.timeout >= (now + 1): self.log.warn("task %s timeout fired prematurely: %s > %s", @@ -590,6 +590,14 @@ class TaskScheduler(SessionFactory): if dep_id not in self.graph: self.graph[dep_id] = set() self.graph[dep_id].add(msg_id) + + # schedule timeout callback + if job.timeout: + timeout_id = job.timeout_id = job.timeout_id + 1 + self.loop.add_timeout(time.time() + job.timeout, + lambda : self.job_timeout(job, timeout_id) + ) + def submit_task(self, job, indices=None): """Submit a task to any of a subset of our targets.""" @@ -633,7 +641,7 @@ class TaskScheduler(SessionFactory): else: self.finish_job(idx) except Exception: - self.log.error("task::Invaid result: %r", raw_msg, exc_info=True) + self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) return md = msg['metadata']