From 6caa08fd57896bcb0b1f9012f380e93f0cf1a3ba 2013-03-29 18:48:00
From: MinRK
Date: 2013-03-29 18:48:00
Subject: [PATCH] use per-timeout callback, rather than audit for timeouts

---

diff --git a/IPython/parallel/controller/scheduler.py b/IPython/parallel/controller/scheduler.py
index e7d8d91..61fa37d 100644
--- a/IPython/parallel/controller/scheduler.py
+++ b/IPython/parallel/controller/scheduler.py
@@ -217,8 +217,6 @@ class TaskScheduler(SessionFactory):
     all_done = Set() # set of all finished tasks=union(completed,failed)
     all_ids = Set() # set of all submitted task IDs
 
-    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
-
     ident = CBytes() # ZMQ identity. This should just be self.session.session
                      # but ensure Bytes
     def _ident_default(self):
@@ -236,9 +234,7 @@ class TaskScheduler(SessionFactory):
             unregistration_notification = self._unregister_engine
         )
         self.notifier_stream.on_recv(self.dispatch_notification)
-        self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
-        self.auditor.start()
-        self.log.info("Scheduler started [%s]"%self.scheme_name)
+        self.log.info("Scheduler started [%s]" % self.scheme_name)
 
     def resume_receiving(self):
         """Resume accepting jobs."""
@@ -438,15 +434,16 @@ class TaskScheduler(SessionFactory):
-        # turn timeouts into datetime objects:
+        # turn timeouts into absolute timestamps (float seconds, as from time.time()):
         timeout = md.get('timeout', None)
         if timeout:
-            # cast to float, because jsonlib returns floats as decimal.Decimal,
-            # which timedelta does not accept
-            timeout = datetime.now() + timedelta(0,float(timeout),0)
+            timeout = time.time() + float(timeout)
 
         job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
                     header=header, targets=targets, after=after, follow=follow,
                     timeout=timeout, metadata=md,
         )
-        
+        if timeout:
+            # schedule timeout callback
+            self.loop.add_timeout(timeout, lambda : self.job_timeout(job))
+
         # validate and reduce dependencies:
         for dep in after,follow:
             if not dep: # empty dependency
@@ -470,21 +467,25 @@ class TaskScheduler(SessionFactory):
         else:
             self.save_unmet(job)
 
-    def audit_timeouts(self):
-        """Audit all waiting tasks for expired timeouts."""
-        now = datetime.now()
-        for msg_id in self.queue_map.keys():
-            # must recheck, in case one failure cascaded to another:
-            if msg_id in self.queue_map:
-                job = self.queue_map[msg_id]
-                if job.timeout and job.timeout < now:
-                    self.fail_unreachable(msg_id, error.TaskTimeout)
+    def job_timeout(self, job):
+        """callback for a job's timeout.
+
+        The job may or may not have been run at this point.
+        """
+        now = time.time()
+        if job.timeout >= (now + 1):
+            self.log.warn("task %s timeout fired prematurely: %s > %s",
+                job.msg_id, job.timeout, now)
+        if job.msg_id in self.queue_map:
+            # still waiting, but ran out of time
+            self.log.info("task %r timed out", job.msg_id)
+            self.fail_unreachable(job.msg_id, error.TaskTimeout)
 
     def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
         """a task has become unreachable, send a reply with an ImpossibleDependency
         error."""
         if msg_id not in self.queue_map:
-            self.log.error("msg %r already failed!", msg_id)
+            self.log.error("task %r already failed!", msg_id)
             return
         job = self.queue_map.pop(msg_id)
         # lazy-delete from the queue
@@ -497,6 +498,7 @@ class TaskScheduler(SessionFactory):
             raise why()
         except:
             content = error.wrap_exception()
+        self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
 
         self.all_done.add(msg_id)
         self.all_failed.add(msg_id)
@@ -791,7 +793,7 @@ def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=Non
                         identity=b'task', in_thread=False):
 
     ZMQStream = zmqstream.ZMQStream
-    
+    loglevel = logging.DEBUG
     if config:
         # unwrap dict back into Config
         config = Config(config)