##// END OF EJS Templates
Test case for setting only lower or upper: there is no sensible default, so this should be an error.

File last commit:

r17098:2fddb199 merge
r17704:400433c5
Show More
engine.py
301 lines | 12.6 KiB | text/x-python | PythonLexer
"""A simple engine that talks to a controller over 0MQ.
it handles registration, etc. and launches a kernel
connected to the Controller's Schedulers.
"""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import print_function
import sys
import time
from getpass import getpass
import zmq
from zmq.eventloop import ioloop, zmqstream
from IPython.utils.localinterfaces import localhost
from IPython.utils.traitlets import (
Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
)
from IPython.utils.py3compat import cast_bytes
from IPython.parallel.controller.heartmonitor import Heart
from IPython.parallel.factory import RegistrationFactory
from IPython.parallel.util import disambiguate_url
from IPython.kernel.zmq.session import Message
from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
from IPython.kernel.zmq.kernelapp import IPKernelApp
class EngineFactory(RegistrationFactory):
    """IPython engine.

    Registers with a controller over 0MQ (see ``register``), and once the
    controller replies, connects the shell/control/iopub channels, starts a
    heartbeat ``Heart`` and launches a kernel bound to those channels (see
    ``complete_registration``). Optionally watches the controller's heartbeat
    pings and shuts the engine down after too many consecutive misses.
    """
    # configurables:
    out_stream_factory = Type('IPython.kernel.zmq.iostream.OutStream', config=True,
        help="""The OutStream for handling stdout/err.
Typically 'IPython.kernel.zmq.iostream.OutStream'""")
    display_hook_factory = Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
        help="""The class for handling displayhook.
Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
    location = Unicode(config=True,
        help="""The location (an IP address) of the controller. This is
used for disambiguating URLs, to determine whether
loopback should be used to connect or the public address.""")
    timeout = Float(5.0, config=True,
        help="""The time (in seconds) to wait for the Controller to respond
to registration requests before giving up.""")
    max_heartbeat_misses = Integer(50, config=True,
        help="""The maximum number of times a check for the heartbeat ping of a
controller can be missed before shutting down the engine.
If set to 0, the check is disabled.""")
    sshserver = Unicode(config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""")
    sshkey = Unicode(config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""")
    paramiko = Bool(sys.platform == 'win32', config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    @property
    def tunnel_mod(self):
        """The ``zmq.ssh.tunnel`` module, imported on access (only needed for ssh)."""
        from zmq.ssh import tunnel
        return tunnel

    # not configurable:
    connection_info = Dict()
    user_ns = Dict()
    id = Integer(allow_none=True)
    registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel = Instance(Kernel)
    hb_check_period = Integer()

    # States for the heartbeat monitoring
    # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
    # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
    _hb_last_pinged = 0.0
    _hb_last_monitored = 0.0
    _hb_missed_beats = 0
    # The zmq Stream which receives the pings from the Heart
    _hb_listener = None

    bident = CBytes()
    ident = Unicode()

    def _ident_changed(self, name, old, new):
        """Trait change handler for ``ident``: keep the bytes copy ``bident`` in sync."""
        self.bident = cast_bytes(new)

    using_ssh = Bool(False)

    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        # The session id doubles as this engine's identity (bident follows via _ident_changed).
        self.ident = self.session.session

    def init_connector(self):
        """Construct the connection functions, which handle ssh tunnels.

        Tunneling is enabled when ``sshkey`` or ``sshserver`` is set. If only a
        key is given, the ssh server is taken from the host part of ``self.url``.
        When passwordless ssh fails, the user is prompted for a password.

        Returns
        -------
        (connect, maybe_tunnel)
            ``connect(s, url)`` connects socket/stream ``s`` to ``url``, through a
            tunnel if ssh is in use. ``maybe_tunnel(url)`` only opens the tunnel
            (if any) and returns the url string to connect to; it is used for the
            heartbeat, whose sockets are created elsewhere.
        """
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % self.sshserver)
        else:
            password = False

        def connect(s, url):
            """Connect ``s`` to ``url``, tunneling through ssh when enabled."""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
                    keyfile=self.sshkey, paramiko=self.paramiko,
                    password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
                url, _tunnel = self.tunnel_mod.open_tunnel(url, self.sshserver,
                    keyfile=self.sshkey, paramiko=self.paramiko,
                    password=password,
                )
            return str(url)

        return connect, maybe_tunnel

    def register(self):
        """Send the registration_request to the controller at ``self.url``.

        The reply is handled asynchronously by ``complete_registration``.
        """
        self.log.info("Registering with controller at %s", self.url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)

        content = dict(uuid=self.ident)
        # complete_registration needs the same connect/maybe_tunnel helpers to
        # reach the remaining controller channels, so bind them into the callback.
        self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        self.session.send(self.registrar, "registration_request", content=content)

    def _report_ping(self, msg):
        """Callback for when the heartmonitor.Heart receives a ping"""
        self._hb_last_pinged = time.time()

    def complete_registration(self, msg, connect, maybe_tunnel):
        """Handle the controller's reply to ``registration_request``.

        If the reply's status is ``'ok'``: store the assigned engine id, start
        the heartbeat (plus a ping monitor when ``max_heartbeat_misses > 0``),
        connect the shell, control and iopub channels, redirect stdout/stderr
        and the display hook to iopub, then create and start the kernel.

        Parameters
        ----------
        msg
            The raw message as delivered by the registrar stream's ``on_recv``.
        connect, maybe_tunnel
            The helpers returned by ``init_connector``.

        Raises
        ------
        Exception
            If the controller reports a status other than ``'ok'``.
        """
        # A reply arrived in time: cancel the pending registration-timeout abort.
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents, msg = self.session.feed_identities(msg)
        msg = self.session.unserialize(msg)
        content = msg['content']
        info = self.connection_info

        def url(key):
            """get zmq url for given channel"""
            return str(info["interface"] + ":%i" % info[key])

        if content['status'] == 'ok':
            self.id = int(content['id'])

            # launch heartbeat
            # possibly forward hb ports with tunnels
            hb_ping = maybe_tunnel(url('hb_ping'))
            hb_pong = maybe_tunnel(url('hb_pong'))

            hb_monitor = None
            if self.max_heartbeat_misses > 0:
                # Add a monitor socket which will record the last time a ping was seen
                mon = self.context.socket(zmq.SUB)
                mport = mon.bind_to_random_port('tcp://%s' % localhost())
                mon.setsockopt(zmq.SUBSCRIBE, b"")
                self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
                self._hb_listener.on_recv(self._report_ping)

                hb_monitor = "tcp://%s:%i" % (localhost(), mport)

            heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
            heart.start()

            # create Shell Connections (MUX, Task, etc.):
            shell_addrs = url('mux'), url('task')

            # Use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)

            # control stream:
            control_addr = url('control')
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub stream:
            iopub_addr = url('iopub')
            iopub_socket = ctx.socket(zmq.PUB)
            iopub_socket.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_socket, iopub_addr)

            # disable history:
            self.config.HistoryManager.hist_file = ':memory:'

            # Redirect output streams (stdout/stderr) and set a display hook,
            # each publishing on iopub under an engine-specific topic.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
                sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
                sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
                sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)

            self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
                    control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
                    loop=loop, user_ns=self.user_ns, log=self.log)

            self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)

            # periodically check the heartbeat pings of the controller
            # Should be started here and not in "start()" so that the right period can be taken
            # from the hubs HeartBeatMonitor.period
            if self.max_heartbeat_misses > 0:
                # Use a slightly bigger check period than the hub signal period so that
                # we do not warn unnecessarily.
                self.hb_check_period = int(content['hb_period']) + 10
                self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms.", self.hb_check_period)
                self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
                self._hb_reporter.start()
            else:
                self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")

            # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
            app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
            app.init_profile_dir()
            app.init_code()

            self.kernel.start()
        else:
            self.log.fatal("Registration Failed: %s", msg)
            raise Exception("Registration Failed: %s" % msg)

        self.log.info("Completed registration with id %i", self.id)

    def abort(self):
        """Give up because the controller did not answer registration in time.

        Logs a fatal message, plus a configuration hint when the registration
        url points at a loopback host. Then sends an unregistration_request,
        sleeps one second, and exits the process with status 255.
        """
        self.log.fatal("Registration timed out after %.1f seconds", self.timeout)
        # self.url looks like 'transport://host:port' (init_connector splits it the
        # same way), so the loopback test must look at the host part; testing the
        # whole url for a '127.' prefix could never match.
        host = self.url.split('://', 1)[-1]
        if host.startswith(('127.', 'localhost')):
            self.log.fatal("""
If the controller and engines are not on the same machine,
you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
c.HubFactory.ip='*' # for all interfaces, internal and external
c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
or tunnel connections via ssh.
""")
        self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def _hb_monitor(self):
        """Callback to monitor the heartbeat from the controller.

        Runs every ``hb_check_period`` ms. If no ping was recorded since the
        previous check, counts a consecutive miss; otherwise resets the count.
        Once ``max_heartbeat_misses`` consecutive misses accumulate, sends an
        unregistration_request and stops the event loop.
        """
        # Handle any pending messages on the listener now (invoking _report_ping),
        # so _hb_last_pinged is up to date before comparing.
        self._hb_listener.flush()
        if self._hb_last_monitored > self._hb_last_pinged:
            self._hb_missed_beats += 1
            self.log.warning("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
        else:
            self._hb_missed_beats = 0

        if self._hb_missed_beats >= self.max_heartbeat_misses:
            self.log.fatal("Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
                           self.max_heartbeat_misses, self.hb_check_period)
            self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
            self.loop.stop()

        self._hb_last_monitored = time.time()

    def start(self):
        """Schedule registration and its timeout on the event loop.

        ``register`` is scheduled with zero delay. ``abort`` is scheduled for
        ``timeout`` seconds later; ``complete_registration`` cancels it when the
        controller replies.
        """
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout * 1000, self.loop)
        self._abort_dc.start()