From d58d98ad24ca29a4043b29dae458f0f6e625f2d2 2011-08-16 22:51:48
From: MinRK <benjaminrk@gmail.com>
Date: 2011-08-16 22:51:48
Subject: [PATCH] add ssh tunneling to Engine

Added an 'enginessh' alias to ipcontroller, mapping to the new IPControllerApp.engine_ssh_server trait.

Added 'ssh' and 'sshkey' aliases to ipengine, mapping to the new EngineFactory.sshserver and EngineFactory.sshkey traits.

---

diff --git a/IPython/parallel/apps/ipcontrollerapp.py b/IPython/parallel/apps/ipcontrollerapp.py
index 695ca17..362318c 100755
--- a/IPython/parallel/apps/ipcontrollerapp.py
+++ b/IPython/parallel/apps/ipcontrollerapp.py
@@ -116,6 +116,7 @@ flags.update(boolean_flag('secure', 'IPControllerApp.secure',
 aliases = dict(
     secure = 'IPControllerApp.secure',
     ssh = 'IPControllerApp.ssh_server',
+    enginessh = 'IPControllerApp.engine_ssh_server',
     location = 'IPControllerApp.location',
 
     ident = 'Session.session',
@@ -158,6 +159,11 @@ class IPControllerApp(BaseParallelApplication):
         processes. It should be of the form: [user@]server[:port]. The
         Controller's listening addresses must be accessible from the ssh server""",
     )
+    engine_ssh_server = Unicode(u'', config=True,
+        help="""ssh url for engines to use when connecting to the Controller
+        processes. It should be of the form: [user@]server[:port]. The
+        Controller's listening addresses must be accessible from the ssh server""",
+    )
     location = Unicode(u'', config=True,
         help="""The external IP or domain name of the Controller, used for disambiguating
         engine and client connections.""",
@@ -218,6 +224,8 @@ class IPControllerApp(BaseParallelApplication):
         c.HubFactory.engine_ip = ip
         c.HubFactory.regport = int(ports)
         self.location = cfg['location']
+        if not self.engine_ssh_server:
+            self.engine_ssh_server = cfg['ssh']
         # load client config
         with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
             cfg = json.loads(f.read())
@@ -226,7 +234,8 @@ class IPControllerApp(BaseParallelApplication):
         c.HubFactory.client_transport = xport
         ip,ports = addr.split(':')
         c.HubFactory.client_ip = ip
-        self.ssh_server = cfg['ssh']
+        if not self.ssh_server:
+            self.ssh_server = cfg['ssh']
         assert int(ports) == c.HubFactory.regport, "regport mismatch"
     
     def init_hub(self):
@@ -271,6 +280,7 @@ class IPControllerApp(BaseParallelApplication):
             self.save_connection_dict('ipcontroller-client.json', cdict)
             edict = cdict
             edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
+            edict['ssh'] = self.engine_ssh_server
             self.save_connection_dict('ipcontroller-engine.json', edict)
 
     #
diff --git a/IPython/parallel/apps/ipengineapp.py b/IPython/parallel/apps/ipengineapp.py
index 1263d91..43fca7a 100755
--- a/IPython/parallel/apps/ipengineapp.py
+++ b/IPython/parallel/apps/ipengineapp.py
@@ -118,6 +118,8 @@ aliases = dict(
     keyfile = 'Session.keyfile',
 
     url = 'EngineFactory.url',
+    ssh = 'EngineFactory.sshserver',
+    sshkey = 'EngineFactory.sshkey',
     ip = 'EngineFactory.ip',
     transport = 'EngineFactory.transport',
     port = 'EngineFactory.regport',
@@ -192,6 +194,40 @@ class IPEngineApp(BaseParallelApplication):
                 self.profile_dir.security_dir,
                 self.url_file_name
             )
+    
+    def load_connector_file(self):
+        """load config from a JSON connector file,
+        at a *lower* priority than command-line/config files.
+        """
+        
+        self.log.info("Loading url_file %r"%self.url_file)
+        config = self.config
+        
+        with open(self.url_file) as f:
+            d = json.loads(f.read())
+        
+        try:
+            config.Session.key
+        except AttributeError:
+            if d['exec_key']:
+                config.Session.key = asbytes(d['exec_key'])
+                
+        try:
+            config.EngineFactory.location
+        except AttributeError:
+            config.EngineFactory.location = d['location']
+        
+        d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
+        try:
+            config.EngineFactory.url
+        except AttributeError:
+            config.EngineFactory.url = d['url']
+        
+        try:
+            config.EngineFactory.sshserver
+        except AttributeError:
+            config.EngineFactory.sshserver = d['ssh']
+        
     def init_engine(self):
         # This is the working dir by now.
         sys.path.insert(0, '')
@@ -219,14 +255,7 @@ class IPEngineApp(BaseParallelApplication):
                 time.sleep(0.1)
             
         if os.path.exists(self.url_file):
-            self.log.info("Loading url_file %r"%self.url_file)
-            with open(self.url_file) as f:
-                d = json.loads(f.read())
-            if d['exec_key']:
-                config.Session.key = asbytes(d['exec_key'])
-            d['url'] = disambiguate_url(d['url'], d['location'])
-            config.EngineFactory.url = d['url']
-            config.EngineFactory.location = d['location']
+            self.load_connector_file()
         elif not url_specified:
             self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
             self.exit(1)
@@ -253,7 +282,7 @@ class IPEngineApp(BaseParallelApplication):
         except:
             self.log.error("Couldn't start the Engine", exc_info=True)
             self.exit(1)
-        
+    
     def forward_logging(self):
         if self.log_url:
             self.log.info("Forwarding logging to %s"%self.log_url)
@@ -265,7 +294,7 @@ class IPEngineApp(BaseParallelApplication):
             handler.setLevel(self.log_level)
             self.log.addHandler(handler)
             self._log_handler = handler
-    #
+    
     def init_mpi(self):
         global mpi
         self.mpi = MPI(config=self.config)
diff --git a/IPython/parallel/engine/engine.py b/IPython/parallel/engine/engine.py
index 59ef87f..777c5d1 100755
--- a/IPython/parallel/engine/engine.py
+++ b/IPython/parallel/engine/engine.py
@@ -17,12 +17,16 @@ from __future__ import print_function
 
 import sys
 import time
+from getpass import getpass
 
 import zmq
 from zmq.eventloop import ioloop, zmqstream
 
+from IPython.external.ssh import tunnel
 # internal
-from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
+from IPython.utils.traitlets import (
+    Instance, Dict, Int, Type, CFloat, Unicode, CBytes, Bool
+)
 # from IPython.utils.localinterfaces import LOCALHOST 
 
 from IPython.parallel.controller.heartmonitor import Heart
@@ -50,6 +54,12 @@ class EngineFactory(RegistrationFactory):
     timeout=CFloat(2,config=True,
         help="""The time (in seconds) to wait for the Controller to respond
         to registration requests before giving up.""")
+    sshserver=Unicode(config=True,
+        help="""The SSH server to use for tunneling connections to the Controller.""")
+    sshkey=Unicode(config=True,
+        help="""The SSH keyfile to use when tunneling connections to the Controller.""")
+    paramiko=Bool(sys.platform == 'win32', config=True,
+        help="""Whether to use paramiko instead of openssh for tunnels.""")
     
     # not configurable:
     user_ns=Dict()
@@ -61,28 +71,70 @@ class EngineFactory(RegistrationFactory):
     ident = Unicode()
     def _ident_changed(self, name, old, new):
         self.bident = asbytes(new)
+    using_ssh=Bool(False)
     
     
     def __init__(self, **kwargs):
         super(EngineFactory, self).__init__(**kwargs)
         self.ident = self.session.session
-        ctx = self.context
+    
+    def init_connector(self):
+        """construct connection function, which handles tunnels."""
+        self.using_ssh = bool(self.sshkey or self.sshserver)
         
-        reg = ctx.socket(zmq.XREQ)
-        reg.setsockopt(zmq.IDENTITY, self.bident)
-        reg.connect(self.url)
-        self.registrar = zmqstream.ZMQStream(reg, self.loop)
+        if self.sshkey and not self.sshserver:
+            # We are using ssh directly to the controller, tunneling localhost to localhost
+            self.sshserver = self.url.split('://')[1].split(':')[0]
+        
+        if self.using_ssh:
+            if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
+                password=False
+            else:
+                password = getpass("SSH Password for %s: "%self.sshserver)
+        else:
+            password = False
+        
+        def connect(s, url):
+            url = disambiguate_url(url, self.location)
+            if self.using_ssh:
+                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
+                return tunnel.tunnel_connection(s, url, self.sshserver,
+                            keyfile=self.sshkey, paramiko=self.paramiko,
+                            password=password,
+                )
+            else:
+                return s.connect(url)
+        
+        def maybe_tunnel(url):
+            """like connect, but don't complete the connection (for use by heartbeat)"""
+            url = disambiguate_url(url, self.location)
+            if self.using_ssh:
+                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
+                url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
+                            keyfile=self.sshkey, paramiko=self.paramiko,
+                            password=password,
+                )
+            return url
+        return connect, maybe_tunnel
         
     def register(self):
         """send the registration_request"""
         
         self.log.info("Registering with controller at %s"%self.url)
+        ctx = self.context
+        connect,maybe_tunnel = self.init_connector()
+        reg = ctx.socket(zmq.XREQ)
+        reg.setsockopt(zmq.IDENTITY, self.bident)
+        connect(reg, self.url)
+        self.registrar = zmqstream.ZMQStream(reg, self.loop)
+        
+        
         content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
-        self.registrar.on_recv(self.complete_registration)
+        self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
         # print (self.session.key)
         self.session.send(self.registrar, "registration_request",content=content)
     
-    def complete_registration(self, msg):
+    def complete_registration(self, msg, connect, maybe_tunnel):
         # print msg
         self._abort_dc.stop()
         ctx = self.context
@@ -94,6 +146,14 @@ class EngineFactory(RegistrationFactory):
         if msg.content.status == 'ok':
             self.id = int(msg.content.id)
             
+            # launch heartbeat
+            hb_addrs = msg.content.heartbeat
+            
+            # possibly forward hb ports with tunnels
+            hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
+            heart = Heart(*map(str, hb_addrs), heart_id=identity)
+            heart.start()
+            
             # create Shell Streams (MUX, Task, etc.):
             queue_addr = msg.content.mux
             shell_addrs = [ str(queue_addr) ]
@@ -114,24 +174,20 @@ class EngineFactory(RegistrationFactory):
             stream.setsockopt(zmq.IDENTITY, identity)
             shell_streams = [stream]
             for addr in shell_addrs:
-                stream.connect(disambiguate_url(addr, self.location))
+                connect(stream, addr)
             # end single stream-socket
             
             # control stream:
             control_addr = str(msg.content.control)
             control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
             control_stream.setsockopt(zmq.IDENTITY, identity)
-            control_stream.connect(disambiguate_url(control_addr, self.location))
+            connect(control_stream, control_addr)
             
             # create iopub stream:
             iopub_addr = msg.content.iopub
             iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
             iopub_stream.setsockopt(zmq.IDENTITY, identity)
-            iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
-            
-            # launch heartbeat
-            hb_addrs = msg.content.heartbeat
-            # print (hb_addrs)
+            connect(iopub_stream, iopub_addr)
             
             # # Redirect input streams and set a display hook.
             if self.out_stream_factory:
@@ -147,9 +203,6 @@ class EngineFactory(RegistrationFactory):
                     control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream, 
                     loop=loop, user_ns = self.user_ns, log=self.log)
             self.kernel.start()
-            hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
-            heart = Heart(*map(str, hb_addrs), heart_id=identity)
-            heart.start()
             
             
         else: