diff --git a/IPython/zmq/forward.py b/IPython/zmq/forward.py index d39ce97..657075a 100644 --- a/IPython/zmq/forward.py +++ b/IPython/zmq/forward.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -# This file is adapted from a paramiko demo, and thus LGPL 2.1. +# This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1. # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com> # Edits Copyright (C) 2010 The IPython Team # @@ -83,7 +83,7 @@ class Handler (SocketServer.BaseRequestHandler): self.request.send(data) chan.close() self.request.close() - verbose('Tunnel closed from %r' % (self.request.getpeername(),)) + verbose('Tunnel closed ') def forward_tunnel(local_port, remote_host, remote_port, transport): @@ -94,7 +94,7 @@ def forward_tunnel(local_port, remote_host, remote_port, transport): chain_host = remote_host chain_port = remote_port ssh_transport = transport - ForwardServer(('', local_port), SubHander).serve_forever() + ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() def verbose(s): diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 4c011a3..5eb2b77 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -19,6 +19,7 @@ import zmq from zmq.eventloop import ioloop, zmqstream from IPython.external.decorator import decorator +from IPython.zmq import tunnel import streamsession as ss # from remotenamespace import RemoteNamespace @@ -117,7 +118,33 @@ class Client(object): addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101' The address of the controller's registration socket. 
- + [Default: 'tcp://127.0.0.1:10101'] + context : zmq.Context + Pass an existing zmq.Context instance, otherwise the client will create its own + username : bytes + set username to be passed to the Session object + debug : bool + flag for lots of message printing for debug purposes + + #-------------- ssh related args ---------------- + # These are args for configuring the ssh tunnel to be used + # credentials are used to forward connections over ssh to the Controller + # Note that the ip given in `addr` needs to be relative to sshserver + # The most basic case is to leave addr as pointing to localhost (127.0.0.1), + # and set sshserver as the same machine the Controller is on. However, + # the only requirement is that sshserver is able to see the Controller + # (i.e. is within the same trusted network). + + sshserver : str + A string of the form passed to ssh, e.g. 'server.tld' or 'user@server.tld:port' + If keyfile or password is specified, and this is not, it will default to + the ip given in addr. + keyfile : str; path to private key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to sshserver. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. 
Attributes ---------- @@ -159,6 +186,7 @@ class Client(object): _connected=False + _ssh=False _engines=None _addr='tcp://127.0.0.1:10101' _registration_socket=None @@ -173,18 +201,33 @@ class Client(object): history = None debug = False - def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False): + def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False, + sshserver=None, keyfile=None, password=None, paramiko=None): if context is None: context = zmq.Context() self.context = context self._addr = addr + self._ssh = bool(sshserver or keyfile or password) + if self._ssh and sshserver is None: + # default to the host given in addr + sshserver = addr.split('://')[1].split(':')[0] + if self._ssh and password is None: + if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko): + password=False + else: + password = getpass("SSH Password for %s: "%sshserver) + ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko) + if username is None: self.session = ss.StreamSession() else: self.session = ss.StreamSession(username) - self._registration_socket = self.context.socket(zmq.PAIR) + self._registration_socket = self.context.socket(zmq.XREQ) self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) - self._registration_socket.connect(addr) + if self._ssh: + tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs) + else: + self._registration_socket.connect(addr) self._engines = {} self._ids = set() self.outstanding=set() @@ -198,7 +241,7 @@ class Client(object): } self._queue_handlers = {'execute_reply' : self._handle_execute_reply, 'apply_reply' : self._handle_apply_reply} - self._connect() + self._connect(sshserver, ssh_kwargs) @property @@ -229,12 +272,19 @@ class Client(object): targets = [targets] return [self._engines[t] for t in targets], list(targets) - def _connect(self): + def _connect(self, sshserver, ssh_kwargs): """setup all our socket connections to the 
controller. This is called from __init__.""" if self._connected: return self._connected=True + + def connect_socket(s, addr): + if self._ssh: + return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs) + else: + return s.connect(addr) + self.session.send(self._registration_socket, 'connection_request') idents,msg = self.session.recv(self._registration_socket,mode=0) if self.debug: @@ -245,23 +295,23 @@ class Client(object): if content.queue: self._mux_socket = self.context.socket(zmq.PAIR) self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session) - self._mux_socket.connect(content.queue) + connect_socket(self._mux_socket, content.queue) if content.task: self._task_socket = self.context.socket(zmq.PAIR) self._task_socket.setsockopt(zmq.IDENTITY, self.session.session) - self._task_socket.connect(content.task) + connect_socket(self._task_socket, content.task) if content.notification: self._notification_socket = self.context.socket(zmq.SUB) - self._notification_socket.connect(content.notification) + connect_socket(self._notification_socket, content.notification) self._notification_socket.setsockopt(zmq.SUBSCRIBE, "") if content.query: self._query_socket = self.context.socket(zmq.PAIR) self._query_socket.setsockopt(zmq.IDENTITY, self.session.session) - self._query_socket.connect(content.query) + connect_socket(self._query_socket, content.query) if content.control: self._control_socket = self.context.socket(zmq.PAIR) self._control_socket.setsockopt(zmq.IDENTITY, self.session.session) - self._control_socket.connect(content.control) + connect_socket(self._control_socket, content.control) self._update_engines(dict(content.engines)) else: @@ -852,4 +902,4 @@ class AsynClient(Client): for stream in (self.queue_stream, self.notifier_stream, self.task_stream, self.control_stream): stream.flush() - \ No newline at end of file + diff --git a/IPython/zmq/tunnel.py b/IPython/zmq/tunnel.py index b6b1a36..0972381 100644 --- a/IPython/zmq/tunnel.py +++ 
b/IPython/zmq/tunnel.py @@ -1,12 +1,21 @@ +"""Basic ssh tunneling utilities.""" +#----------------------------------------------------------------------------- +# Copyright (C) 2008-2010 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#----------------------------------------------------------------------------- -#----------------------------------------- + + +#----------------------------------------------------------------------------- # Imports -#----------------------------------------- +#----------------------------------------------------------------------------- from __future__ import print_function -import os,sys +import os,sys, atexit from multiprocessing import Process from getpass import getpass, getuser @@ -16,20 +25,137 @@ except ImportError: paramiko = None else: from forward import forward_tunnel + +try: + from IPython.external import pexpect +except ImportError: + pexpect = None + +from IPython.zmq.parallel.entry_point import select_random_ports + +#----------------------------------------------------------------------------- +# Code +#----------------------------------------------------------------------------- + +#----------------------------------------------------------------------------- +# Check for passwordless login +#----------------------------------------------------------------------------- + +def try_passwordless_ssh(server, keyfile, paramiko=None): + """Attempt to make an ssh connection without a password. + This is mainly used for requiring password input only once + when many tunnels may be connected to the same server. -from IPython.external import pexpect + If paramiko is None, the default for the platform is chosen. 
+ """ + if paramiko is None: + paramiko = sys.platform == 'win32' + if not paramiko: + f = _try_passwordless_openssh + else: + f = _try_passwordless_paramiko + return f(server, keyfile) +def _try_passwordless_openssh(server, keyfile): + """Try passwordless login with shell ssh command.""" + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko") + cmd = 'ssh -f '+ server + if keyfile: + cmd += ' -i ' + keyfile + cmd += ' exit' + p = pexpect.spawn(cmd) + while True: + try: + p.expect('[Pp]assword:', timeout=.1) + except pexpect.TIMEOUT: + continue + except pexpect.EOF: + return True + else: + return False -def launch_ssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, timeout=15): +def _try_passwordless_paramiko(server, keyfile): + """Try passwordless login with paramiko.""" + if paramiko is None: + raise ImportError("paramiko unavailable, use openssh") + username, server, port = _split_server(server) + client = paramiko.SSHClient() + client.load_system_host_keys() + client.set_missing_host_key_policy(paramiko.WarningPolicy()) + try: + client.connect(server, port, username=username, key_filename=keyfile, + look_for_keys=True) + except paramiko.AuthenticationException: + return False + else: + client.close() + return True + + +def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None): + """Connect a socket to an address via an ssh tunnel. + + This is a wrapper for socket.connect(addr), when addr is not accessible + from the local machine. It simply creates an ssh tunnel using the remaining args, + and calls socket.connect('tcp://127.0.0.1:lport') where lport is the randomly + selected local port of the tunnel. 
+ + """ + lport = select_random_ports(1)[0] + transport, addr = addr.split('://') + ip,rport = addr.split(':') + rport = int(rport) + if paramiko is None: + paramiko = sys.platform == 'win32' + if paramiko: + tunnelf = paramiko_tunnel + else: + tunnelf = openssh_tunnel + tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password) + socket.connect('tcp://127.0.0.1:%i'%lport) + return tunnel + +def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15): """Create an ssh tunnel using command-line ssh that connects port lport on this machine to localhost:rport on server. The tunnel will automatically close when not in use, remaining open for a minimum of timeout seconds for an initial connection. + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + keyfile and password may be specified, but ssh config is checked for defaults. + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to, passed to the ssh command unchanged, + e.g. user@server (a ':port' suffix is not parsed here). + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to private key file + This specifies a key to be used in ssh login, default None. + Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. 
+ """ + if pexpect is None: + raise ImportError("pexpect unavailable, use paramiko_tunnel") ssh="ssh " if keyfile: ssh += "-i " + keyfile - cmd = ssh + " -f -L %i:127.0.0.1:%i %s sleep %i"%(lport, rport, server, timeout) + cmd = ssh + " -f -L 127.0.0.1:%i:%s:%i %s sleep %i"%(lport, remoteip, rport, server, timeout) tunnel = pexpect.spawn(cmd) failed = False while True: @@ -48,7 +174,10 @@ def launch_ssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, else: if failed: print("Password rejected, try again") - tunnel.sendline(getpass()) + password=None + if password is None: + password = getpass("%s's password: "%(server)) + tunnel.sendline(password) failed = True def _split_server(server): @@ -63,28 +192,62 @@ def _split_server(server): port = 22 return username, server, port -def launch_paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None): - """launch a tunner with paramiko in a subprocess""" +def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15): + """launch a tunnel with paramiko in a subprocess. This should only be used + when shell ssh is unavailable (e.g. Windows). + + This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`, + as seen from `server`. + + keyfile and password may be specified, but ssh config is checked for defaults. + + Parameters + ---------- + + lport : int + local port for connecting to the tunnel from this machine. + rport : int + port on the remote machine to connect to. + server : str + The ssh server to connect to. The full ssh server string will be parsed. + user@server:port + remoteip : str [Default: 127.0.0.1] + The remote ip, specifying the destination of the tunnel. + Default is localhost, which means that the tunnel would redirect + localhost:lport on this machine to localhost:rport on the *server*. + + keyfile : str; path to private key file + This specifies a key to be used in ssh login, default None. 
+ Regular default ssh keys will be used without specifying this argument. + password : str; + Your ssh password to the ssh server. Note that if this is left None, + you will be prompted for it if passwordless key based login is unavailable. + + """ if paramiko is None: raise ImportError("Paramiko not available") - server = _split_server(server) - if keyfile is None: - passwd = getpass("%s@%s's password: "%(server[0], server[1])) - else: - passwd = None + + if password is None: + if not _try_passwordless_paramiko(server, keyfile): + password = getpass("%s's password: "%(server)) + p = Process(target=_paramiko_tunnel, args=(lport, rport, server, remoteip), - kwargs=dict(keyfile=keyfile, password=passwd)) + kwargs=dict(keyfile=keyfile, password=password)) p.daemon=False p.start() + atexit.register(_shutdown_process, p) return p +def _shutdown_process(p): + if p.is_alive(): + p.terminate() def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None): """function for actually starting a paramiko tunnel, to be passed to multiprocessing.Process(target=this). """ - username, server, port = server + username, server, port = _split_server(server) client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) @@ -92,20 +255,34 @@ def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None try: client.connect(server, port, username=username, key_filename=keyfile, look_for_keys=True, password=password) +# except paramiko.AuthenticationException: +# if password is None: +# password = getpass("%s@%s's password: "%(username, server)) +# client.connect(server, port, username=username, password=password) +# else: +# raise except Exception as e: print ('*** Failed to connect to %s:%d: %r' % (server, port, e)) sys.exit(1) - print ('Now forwarding port %d to %s:%d ...' % (lport, server, rport)) + # print ('Now forwarding port %d to %s:%d ...' 
% (lport, server, rport)) try: forward_tunnel(lport, remoteip, rport, client.get_transport()) except KeyboardInterrupt: - print ('C-c: Port forwarding stopped.') + print ('SIGINT: Port forwarding stopped cleanly') sys.exit(0) + except Exception as e: + print ("Port forwarding stopped uncleanly: %s"%e) + sys.exit(255) + +if sys.platform == 'win32': + ssh_tunnel = paramiko_tunnel +else: + ssh_tunnel = openssh_tunnel -__all__ = ['launch_ssh_tunnel', 'launch_paramiko_tunnel'] +__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh']