diff --git a/IPython/parallel/apps/ipcontrollerapp.py b/IPython/parallel/apps/ipcontrollerapp.py index 695ca17..362318c 100755 --- a/IPython/parallel/apps/ipcontrollerapp.py +++ b/IPython/parallel/apps/ipcontrollerapp.py @@ -116,6 +116,7 @@ flags.update(boolean_flag('secure', 'IPControllerApp.secure', aliases = dict( secure = 'IPControllerApp.secure', ssh = 'IPControllerApp.ssh_server', + enginessh = 'IPControllerApp.engine_ssh_server', location = 'IPControllerApp.location', ident = 'Session.session', @@ -158,6 +159,11 @@ class IPControllerApp(BaseParallelApplication): processes. It should be of the form: [user@]server[:port]. The Controller's listening addresses must be accessible from the ssh server""", ) + engine_ssh_server = Unicode(u'', config=True, + help="""ssh url for engines to use when connecting to the Controller + processes. It should be of the form: [user@]server[:port]. The + Controller's listening addresses must be accessible from the ssh server""", + ) location = Unicode(u'', config=True, help="""The external IP or domain name of the Controller, used for disambiguating engine and client connections.""", @@ -218,6 +224,8 @@ class IPControllerApp(BaseParallelApplication): c.HubFactory.engine_ip = ip c.HubFactory.regport = int(ports) self.location = cfg['location'] + if not self.engine_ssh_server: + self.engine_ssh_server = cfg['ssh'] # load client config with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f: cfg = json.loads(f.read()) @@ -226,7 +234,8 @@ class IPControllerApp(BaseParallelApplication): c.HubFactory.client_transport = xport ip,ports = addr.split(':') c.HubFactory.client_ip = ip - self.ssh_server = cfg['ssh'] + if not self.ssh_server: + self.ssh_server = cfg['ssh'] assert int(ports) == c.HubFactory.regport, "regport mismatch" def init_hub(self): @@ -271,6 +280,7 @@ class IPControllerApp(BaseParallelApplication): self.save_connection_dict('ipcontroller-client.json', cdict) edict = cdict edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport)) + edict['ssh'] = self.engine_ssh_server self.save_connection_dict('ipcontroller-engine.json', edict) # diff --git a/IPython/parallel/apps/ipengineapp.py b/IPython/parallel/apps/ipengineapp.py index 1263d91..43fca7a 100755 --- a/IPython/parallel/apps/ipengineapp.py +++ b/IPython/parallel/apps/ipengineapp.py @@ -118,6 +118,8 @@ aliases = dict( keyfile = 'Session.keyfile', url = 'EngineFactory.url', + ssh = 'EngineFactory.sshserver', + sshkey = 'EngineFactory.sshkey', ip = 'EngineFactory.ip', transport = 'EngineFactory.transport', port = 'EngineFactory.regport', @@ -192,6 +194,40 @@ class IPEngineApp(BaseParallelApplication): self.profile_dir.security_dir, self.url_file_name ) + + def load_connector_file(self): + """load config from a JSON connector file, + at a *lower* priority than command-line/config files. + """ + + self.log.info("Loading url_file %r"%self.url_file) + config = self.config + + with open(self.url_file) as f: + d = json.loads(f.read()) + + try: + config.Session.key + except AttributeError: + if d['exec_key']: + config.Session.key = asbytes(d['exec_key']) + + try: + config.EngineFactory.location + except AttributeError: + config.EngineFactory.location = d['location'] + + d['url'] = disambiguate_url(d['url'], config.EngineFactory.location) + try: + config.EngineFactory.url + except AttributeError: + config.EngineFactory.url = d['url'] + + try: + config.EngineFactory.sshserver + except AttributeError: + config.EngineFactory.sshserver = d['ssh'] + def init_engine(self): # This is the working dir by now. sys.path.insert(0, '') @@ -219,14 +255,7 @@ class IPEngineApp(BaseParallelApplication): time.sleep(0.1) if os.path.exists(self.url_file): - self.log.info("Loading url_file %r"%self.url_file) - with open(self.url_file) as f: - d = json.loads(f.read()) - if d['exec_key']: - config.Session.key = asbytes(d['exec_key']) - d['url'] = disambiguate_url(d['url'], d['location']) - config.EngineFactory.url = d['url'] - config.EngineFactory.location = d['location'] + self.load_connector_file() elif not url_specified: self.log.critical("Fatal: url file never arrived: %s"%self.url_file) self.exit(1) @@ -253,7 +282,7 @@ class IPEngineApp(BaseParallelApplication): except: self.log.error("Couldn't start the Engine", exc_info=True) self.exit(1) - + def forward_logging(self): if self.log_url: self.log.info("Forwarding logging to %s"%self.log_url) @@ -265,7 +294,7 @@ class IPEngineApp(BaseParallelApplication): handler.setLevel(self.log_level) self.log.addHandler(handler) self._log_handler = handler - # + def init_mpi(self): global mpi self.mpi = MPI(config=self.config) diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py index 59ef87f..777c5d1 100755 --- a/IPython/parallel/engine/engine.py +++ b/IPython/parallel/engine/engine.py @@ -17,12 +17,16 @@ from __future__ import print_function import sys import time +from getpass import getpass import zmq from zmq.eventloop import ioloop, zmqstream +from IPython.external.ssh import tunnel # internal -from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes +from IPython.utils.traitlets import ( + Instance, Dict, Int, Type, CFloat, Unicode, CBytes, Bool +) # from IPython.utils.localinterfaces import LOCALHOST from IPython.parallel.controller.heartmonitor import Heart @@ -50,6 +54,12 @@ class EngineFactory(RegistrationFactory): timeout=CFloat(2,config=True, help="""The time (in seconds) to wait for the Controller to respond to registration requests before giving up.""") + sshserver=Unicode(config=True, + help="""The SSH server to use for tunneling connections to the Controller.""") + sshkey=Unicode(config=True, + help="""The SSH keyfile to use when tunneling connections to the Controller.""") + paramiko=Bool(sys.platform == 'win32', config=True, + help="""Whether to use paramiko instead of openssh for tunnels.""") # not configurable: user_ns=Dict() @@ -61,28 +71,70 @@ class EngineFactory(RegistrationFactory): ident = Unicode() def _ident_changed(self, name, old, new): self.bident = asbytes(new) + using_ssh=Bool(False) def __init__(self, **kwargs): super(EngineFactory, self).__init__(**kwargs) self.ident = self.session.session - ctx = self.context + + def init_connector(self): + """construct connection function, which handles tunnels.""" + self.using_ssh = bool(self.sshkey or self.sshserver) - reg = ctx.socket(zmq.XREQ) - reg.setsockopt(zmq.IDENTITY, self.bident) - reg.connect(self.url) - self.registrar = zmqstream.ZMQStream(reg, self.loop) + if self.sshkey and not self.sshserver: + # We are using ssh directly to the controller, tunneling localhost to localhost + self.sshserver = self.url.split('://')[1].split(':')[0] + + if self.using_ssh: + if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko): + password=False + else: + password = getpass("SSH Password for %s: "%self.sshserver) + else: + password = False + + def connect(s, url): + url = disambiguate_url(url, self.location) + if self.using_ssh: + self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver)) + return tunnel.tunnel_connection(s, url, self.sshserver, + keyfile=self.sshkey, paramiko=self.paramiko, + password=password, + ) + else: + return s.connect(url) + + def maybe_tunnel(url): + """like connect, but don't complete the connection (for use by heartbeat)""" + url = disambiguate_url(url, self.location) + if self.using_ssh: + self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver)) + url,tunnelobj = tunnel.open_tunnel(url, self.sshserver, + keyfile=self.sshkey, paramiko=self.paramiko, + password=password, + ) + return url + return connect, maybe_tunnel def register(self): """send the registration_request""" self.log.info("Registering with controller at %s"%self.url) + ctx = self.context + connect,maybe_tunnel = self.init_connector() + reg = ctx.socket(zmq.XREQ) + reg.setsockopt(zmq.IDENTITY, self.bident) + connect(reg, self.url) + self.registrar = zmqstream.ZMQStream(reg, self.loop) + + content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) - self.registrar.on_recv(self.complete_registration) + self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel)) # print (self.session.key) self.session.send(self.registrar, "registration_request",content=content) - def complete_registration(self, msg): + def complete_registration(self, msg, connect, maybe_tunnel): # print msg self._abort_dc.stop() ctx = self.context @@ -94,6 +146,14 @@ class EngineFactory(RegistrationFactory): if msg.content.status == 'ok': self.id = int(msg.content.id) + # launch heartbeat + hb_addrs = msg.content.heartbeat + + # possibly forward hb ports with tunnels + hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ] + heart = Heart(*map(str, hb_addrs), heart_id=identity) + heart.start() + # create Shell Streams (MUX, Task, etc.): queue_addr = msg.content.mux shell_addrs = [ str(queue_addr) ] @@ -114,24 +174,20 @@ class EngineFactory(RegistrationFactory): stream.setsockopt(zmq.IDENTITY, identity) shell_streams = [stream] for addr in shell_addrs: - stream.connect(disambiguate_url(addr, self.location)) + connect(stream, addr) # end single stream-socket # control stream: control_addr = str(msg.content.control) control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop) control_stream.setsockopt(zmq.IDENTITY, identity) - control_stream.connect(disambiguate_url(control_addr, self.location)) + connect(control_stream, control_addr) # create iopub stream: iopub_addr = msg.content.iopub iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop) iopub_stream.setsockopt(zmq.IDENTITY, identity) - iopub_stream.connect(disambiguate_url(iopub_addr, self.location)) - - # launch heartbeat - hb_addrs = msg.content.heartbeat - # print (hb_addrs) + connect(iopub_stream, iopub_addr) # # Redirect input streams and set a display hook. if self.out_stream_factory: @@ -147,9 +203,6 @@ class EngineFactory(RegistrationFactory): control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream, loop=loop, user_ns = self.user_ns, log=self.log) self.kernel.start() - hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ] - heart = Heart(*map(str, hb_addrs), heart_id=identity) - heart.start() else: