diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py
index 07b078b..46ba1f2 100644
--- a/IPython/parallel/engine/engine.py
+++ b/IPython/parallel/engine/engine.py
@@ -25,7 +25,7 @@ from zmq.eventloop import ioloop, zmqstream
 from IPython.external.ssh import tunnel
 # internal
 from IPython.utils.traitlets import (
-    Instance, Dict, Integer, Type, CFloat, Unicode, CBytes, Bool
+    Instance, Dict, Integer, Type, CFloat, CInt, Unicode, CBytes, Bool
 )
 from IPython.utils.py3compat import cast_bytes
 
@@ -53,6 +53,14 @@ class EngineFactory(RegistrationFactory):
     timeout=CFloat(5, config=True,
         help="""The time (in seconds) to wait for the Controller to respond
         to registration requests before giving up.""")
+    hb_check_period=CFloat(5, config=True,
+        help="""The interval (in seconds) at which to check whether a heartbeat
+        ping has been received from the Controller. Set to 0 to disable
+        heartbeat monitoring.""")
+    hb_max_misses=CInt(5, config=True,
+        help="""The maximum number of consecutive heartbeat checks without a ping
+        from the Controller before the engine shuts down. Set to 0 to disable
+        heartbeat monitoring.""")
     sshserver=Unicode(config=True,
         help="""The SSH server to use for tunneling connections to the Controller.""")
     sshkey=Unicode(config=True,
@@ -66,6 +74,17 @@ class EngineFactory(RegistrationFactory):
     id = Integer(allow_none=True)
     registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
     kernel = Instance(Kernel)
+
+    # States for the heartbeat monitoring. The timestamps start at 0.0
+    # rather than None so the '>' comparison in _hb_monitor also works
+    # on Python 3, where None cannot be ordered against floats.
+    _hb_last_pinged = 0.0
+    _hb_last_monitored = 0.0
+    _hb_missed_beats = 0
+    # The zmq Stream which receives the pings from the Heart
+    _hb_listener = None
+    # The PeriodicCallback that runs _hb_monitor (None while monitoring is off)
+    _hb_reporter = None
 
     bident = CBytes()
     ident = Unicode()
@@ -134,6 +153,11 @@ class EngineFactory(RegistrationFactory):
         # print (self.session.key)
         self.session.send(self.registrar, "registration_request", content=content)
 
+    def _report_ping(self, msg):
+        """Record the time of a ping forwarded by the Heart's monitor socket."""
+        self.log.debug("Received a ping: %s", msg)
+        self._hb_last_pinged = time.time()
+
     def complete_registration(self, msg, connect, maybe_tunnel):
         # print msg
         self._abort_dc.stop()
@@ -156,8 +180,27 @@ class EngineFactory(RegistrationFactory):
             # possibly forward hb ports with tunnels
             hb_ping = maybe_tunnel(url('hb_ping'))
             hb_pong = maybe_tunnel(url('hb_pong'))
+
+            # Add a monitor socket which will record the last time a ping was seen
+            mon = self.context.socket(zmq.SUB)
+            mport = mon.bind_to_random_port('tcp://127.0.0.1')
+            mon.setsockopt(zmq.SUBSCRIBE, b"")
+            self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
+            self._hb_listener.on_recv(self._report_ping)
+
+            hb_monitor = "tcp://127.0.0.1:%i" % mport
 
-            heart = Heart(hb_ping, hb_pong, heart_id=identity)
+            heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
             heart.start()
+
+            # Check for pings only from here on: _hb_monitor needs the
+            # listener stream created above. A non-positive period or miss
+            # count disables the check.
+            if self.hb_check_period > 0 and self.hb_max_misses > 0:
+                self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor,
+                        self.hb_check_period * 1000, self.loop)
+                self._hb_reporter.start()
+            else:
+                self.log.info("Heartbeat monitoring of the Controller is disabled.")
 
             # create Shell Connections (MUX, Task, etc.):
@@ -228,6 +271,31 @@ class EngineFactory(RegistrationFactory):
         time.sleep(1)
         sys.exit(255)
 
+    def _hb_monitor(self):
+        """Check whether the Controller pinged since the previous check.
+
+        Called periodically by ``self._hb_reporter``. After
+        ``hb_max_misses`` consecutive checks without a ping, the engine
+        sends an unregistration request and stops its event loop.
+        """
+        # Hand any queued pings to _report_ping first, so that
+        # _hb_last_pinged is current before the comparison below.
+        self._hb_listener.flush()
+        if self._hb_last_monitored > self._hb_last_pinged:
+            self._hb_missed_beats += 1
+            self.log.warning("No heartbeat in the last %s seconds.", self.hb_check_period)
+        else:
+            self._hb_missed_beats = 0
+
+        if self._hb_missed_beats >= self.hb_max_misses:
+            self.log.fatal("Maximum number of heartbeats misses reached (%s times %s seconds), shutting down.",
+                    self.hb_max_misses, self.hb_check_period)
+            self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
+            self._hb_reporter.stop()
+            self.loop.stop()
+
+        self._hb_last_monitored = time.time()
+
     def start(self):
         dc = ioloop.DelayedCallback(self.register, 0, self.loop)
         dc.start()