##// END OF EJS Templates
added preliminary ssh tunneling support for clients
Min RK -
Show More
@@ -1,183 +1,183
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2
2
3 #
3 #
4 # This file is adapted from a paramiko demo, and thus LGPL 2.1.
4 # This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
5 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
5 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
6 # Edits Copyright (C) 2010 The IPython Team
6 # Edits Copyright (C) 2010 The IPython Team
7 #
7 #
8 # Paramiko is free software; you can redistribute it and/or modify it under the
8 # Paramiko is free software; you can redistribute it and/or modify it under the
9 # terms of the GNU Lesser General Public License as published by the Free
9 # terms of the GNU Lesser General Public License as published by the Free
10 # Software Foundation; either version 2.1 of the License, or (at your option)
10 # Software Foundation; either version 2.1 of the License, or (at your option)
11 # any later version.
11 # any later version.
12 #
12 #
13 # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
13 # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
14 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
14 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
15 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
16 # details.
16 # details.
17 #
17 #
18 # You should have received a copy of the GNU Lesser General Public License
18 # You should have received a copy of the GNU Lesser General Public License
19 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
19 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
20 # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
20 # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
21
21
22 """
22 """
23 Sample script showing how to do local port forwarding over paramiko.
23 Sample script showing how to do local port forwarding over paramiko.
24
24
25 This script connects to the requested SSH server and sets up local port
25 This script connects to the requested SSH server and sets up local port
26 forwarding (the openssh -L option) from a local port through a tunneled
26 forwarding (the openssh -L option) from a local port through a tunneled
27 connection to a destination reachable from the SSH server machine.
27 connection to a destination reachable from the SSH server machine.
28 """
28 """
29
29
30 from __future__ import print_function
30 from __future__ import print_function
31
31
32 import getpass
32 import getpass
33 import os
33 import os
34 import socket
34 import socket
35 import select
35 import select
36 import SocketServer
36 import SocketServer
37 import sys
37 import sys
38 from optparse import OptionParser
38 from optparse import OptionParser
39
39
40 import paramiko
40 import paramiko
41
41
42 SSH_PORT = 22
42 SSH_PORT = 22
43 DEFAULT_PORT = 4000
43 DEFAULT_PORT = 4000
44
44
45 g_verbose = False
45 g_verbose = False
46
46
47
47
48 class ForwardServer (SocketServer.ThreadingTCPServer):
48 class ForwardServer (SocketServer.ThreadingTCPServer):
49 daemon_threads = True
49 daemon_threads = True
50 allow_reuse_address = True
50 allow_reuse_address = True
51
51
52
52
53 class Handler (SocketServer.BaseRequestHandler):
53 class Handler (SocketServer.BaseRequestHandler):
54
54
55 def handle(self):
55 def handle(self):
56 try:
56 try:
57 chan = self.ssh_transport.open_channel('direct-tcpip',
57 chan = self.ssh_transport.open_channel('direct-tcpip',
58 (self.chain_host, self.chain_port),
58 (self.chain_host, self.chain_port),
59 self.request.getpeername())
59 self.request.getpeername())
60 except Exception, e:
60 except Exception, e:
61 verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
61 verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
62 self.chain_port,
62 self.chain_port,
63 repr(e)))
63 repr(e)))
64 return
64 return
65 if chan is None:
65 if chan is None:
66 verbose('Incoming request to %s:%d was rejected by the SSH server.' %
66 verbose('Incoming request to %s:%d was rejected by the SSH server.' %
67 (self.chain_host, self.chain_port))
67 (self.chain_host, self.chain_port))
68 return
68 return
69
69
70 verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
70 verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
71 chan.getpeername(), (self.chain_host, self.chain_port)))
71 chan.getpeername(), (self.chain_host, self.chain_port)))
72 while True:
72 while True:
73 r, w, x = select.select([self.request, chan], [], [])
73 r, w, x = select.select([self.request, chan], [], [])
74 if self.request in r:
74 if self.request in r:
75 data = self.request.recv(1024)
75 data = self.request.recv(1024)
76 if len(data) == 0:
76 if len(data) == 0:
77 break
77 break
78 chan.send(data)
78 chan.send(data)
79 if chan in r:
79 if chan in r:
80 data = chan.recv(1024)
80 data = chan.recv(1024)
81 if len(data) == 0:
81 if len(data) == 0:
82 break
82 break
83 self.request.send(data)
83 self.request.send(data)
84 chan.close()
84 chan.close()
85 self.request.close()
85 self.request.close()
86 verbose('Tunnel closed from %r' % (self.request.getpeername(),))
86 verbose('Tunnel closed ')
87
87
88
88
89 def forward_tunnel(local_port, remote_host, remote_port, transport):
89 def forward_tunnel(local_port, remote_host, remote_port, transport):
90 # this is a little convoluted, but lets me configure things for the Handler
90 # this is a little convoluted, but lets me configure things for the Handler
91 # object. (SocketServer doesn't give Handlers any way to access the outer
91 # object. (SocketServer doesn't give Handlers any way to access the outer
92 # server normally.)
92 # server normally.)
93 class SubHander (Handler):
93 class SubHander (Handler):
94 chain_host = remote_host
94 chain_host = remote_host
95 chain_port = remote_port
95 chain_port = remote_port
96 ssh_transport = transport
96 ssh_transport = transport
97 ForwardServer(('', local_port), SubHander).serve_forever()
97 ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
98
98
99
99
100 def verbose(s):
100 def verbose(s):
101 if g_verbose:
101 if g_verbose:
102 print (s)
102 print (s)
103
103
104
104
105 HELP = """\
105 HELP = """\
106 Set up a forward tunnel across an SSH server, using paramiko. A local port
106 Set up a forward tunnel across an SSH server, using paramiko. A local port
107 (given with -p) is forwarded across an SSH session to an address:port from
107 (given with -p) is forwarded across an SSH session to an address:port from
108 the SSH server. This is similar to the openssh -L option.
108 the SSH server. This is similar to the openssh -L option.
109 """
109 """
110
110
111
111
112 def get_host_port(spec, default_port):
112 def get_host_port(spec, default_port):
113 "parse 'hostname:22' into a host and port, with the port optional"
113 "parse 'hostname:22' into a host and port, with the port optional"
114 args = (spec.split(':', 1) + [default_port])[:2]
114 args = (spec.split(':', 1) + [default_port])[:2]
115 args[1] = int(args[1])
115 args[1] = int(args[1])
116 return args[0], args[1]
116 return args[0], args[1]
117
117
118
118
119 def parse_options():
119 def parse_options():
120 global g_verbose
120 global g_verbose
121
121
122 parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]',
122 parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]',
123 version='%prog 1.0', description=HELP)
123 version='%prog 1.0', description=HELP)
124 parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True,
124 parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True,
125 help='squelch all informational output')
125 help='squelch all informational output')
126 parser.add_option('-p', '--local-port', action='store', type='int', dest='port',
126 parser.add_option('-p', '--local-port', action='store', type='int', dest='port',
127 default=DEFAULT_PORT,
127 default=DEFAULT_PORT,
128 help='local port to forward (default: %d)' % DEFAULT_PORT)
128 help='local port to forward (default: %d)' % DEFAULT_PORT)
129 parser.add_option('-u', '--user', action='store', type='string', dest='user',
129 parser.add_option('-u', '--user', action='store', type='string', dest='user',
130 default=getpass.getuser(),
130 default=getpass.getuser(),
131 help='username for SSH authentication (default: %s)' % getpass.getuser())
131 help='username for SSH authentication (default: %s)' % getpass.getuser())
132 parser.add_option('-K', '--key', action='store', type='string', dest='keyfile',
132 parser.add_option('-K', '--key', action='store', type='string', dest='keyfile',
133 default=None,
133 default=None,
134 help='private key file to use for SSH authentication')
134 help='private key file to use for SSH authentication')
135 parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True,
135 parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True,
136 help='don\'t look for or use a private key file')
136 help='don\'t look for or use a private key file')
137 parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False,
137 parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False,
138 help='read password (for key or password auth) from stdin')
138 help='read password (for key or password auth) from stdin')
139 parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port',
139 parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port',
140 help='remote host and port to forward to')
140 help='remote host and port to forward to')
141 options, args = parser.parse_args()
141 options, args = parser.parse_args()
142
142
143 if len(args) != 1:
143 if len(args) != 1:
144 parser.error('Incorrect number of arguments.')
144 parser.error('Incorrect number of arguments.')
145 if options.remote is None:
145 if options.remote is None:
146 parser.error('Remote address required (-r).')
146 parser.error('Remote address required (-r).')
147
147
148 g_verbose = options.verbose
148 g_verbose = options.verbose
149 server_host, server_port = get_host_port(args[0], SSH_PORT)
149 server_host, server_port = get_host_port(args[0], SSH_PORT)
150 remote_host, remote_port = get_host_port(options.remote, SSH_PORT)
150 remote_host, remote_port = get_host_port(options.remote, SSH_PORT)
151 return options, (server_host, server_port), (remote_host, remote_port)
151 return options, (server_host, server_port), (remote_host, remote_port)
152
152
153
153
154 def main():
154 def main():
155 options, server, remote = parse_options()
155 options, server, remote = parse_options()
156
156
157 password = None
157 password = None
158 if options.readpass:
158 if options.readpass:
159 password = getpass.getpass('Enter SSH password: ')
159 password = getpass.getpass('Enter SSH password: ')
160
160
161 client = paramiko.SSHClient()
161 client = paramiko.SSHClient()
162 client.load_system_host_keys()
162 client.load_system_host_keys()
163 client.set_missing_host_key_policy(paramiko.WarningPolicy())
163 client.set_missing_host_key_policy(paramiko.WarningPolicy())
164
164
165 verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1]))
165 verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1]))
166 try:
166 try:
167 client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
167 client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
168 look_for_keys=options.look_for_keys, password=password)
168 look_for_keys=options.look_for_keys, password=password)
169 except Exception as e:
169 except Exception as e:
170 print ('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
170 print ('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
171 sys.exit(1)
171 sys.exit(1)
172
172
173 verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
173 verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
174
174
175 try:
175 try:
176 forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
176 forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
177 except KeyboardInterrupt:
177 except KeyboardInterrupt:
178 print ('C-c: Port forwarding stopped.')
178 print ('C-c: Port forwarding stopped.')
179 sys.exit(0)
179 sys.exit(0)
180
180
181
181
182 if __name__ == '__main__':
182 if __name__ == '__main__':
183 main()
183 main()
@@ -1,855 +1,905
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 import time
15 import time
16 from pprint import pprint
16 from pprint import pprint
17
17
18 import zmq
18 import zmq
19 from zmq.eventloop import ioloop, zmqstream
19 from zmq.eventloop import ioloop, zmqstream
20
20
21 from IPython.external.decorator import decorator
21 from IPython.external.decorator import decorator
22 from IPython.zmq import tunnel
22
23
23 import streamsession as ss
24 import streamsession as ss
24 # from remotenamespace import RemoteNamespace
25 # from remotenamespace import RemoteNamespace
25 from view import DirectView, LoadBalancedView
26 from view import DirectView, LoadBalancedView
26 from dependency import Dependency, depend, require
27 from dependency import Dependency, depend, require
27
28
28 def _push(ns):
29 def _push(ns):
29 globals().update(ns)
30 globals().update(ns)
30
31
31 def _pull(keys):
32 def _pull(keys):
32 g = globals()
33 g = globals()
33 if isinstance(keys, (list,tuple, set)):
34 if isinstance(keys, (list,tuple, set)):
34 for key in keys:
35 for key in keys:
35 if not g.has_key(key):
36 if not g.has_key(key):
36 raise NameError("name '%s' is not defined"%key)
37 raise NameError("name '%s' is not defined"%key)
37 return map(g.get, keys)
38 return map(g.get, keys)
38 else:
39 else:
39 if not g.has_key(keys):
40 if not g.has_key(keys):
40 raise NameError("name '%s' is not defined"%keys)
41 raise NameError("name '%s' is not defined"%keys)
41 return g.get(keys)
42 return g.get(keys)
42
43
43 def _clear():
44 def _clear():
44 globals().clear()
45 globals().clear()
45
46
46 def execute(code):
47 def execute(code):
47 exec code in globals()
48 exec code in globals()
48
49
49 #--------------------------------------------------------------------------
50 #--------------------------------------------------------------------------
50 # Decorators for Client methods
51 # Decorators for Client methods
51 #--------------------------------------------------------------------------
52 #--------------------------------------------------------------------------
52
53
53 @decorator
54 @decorator
54 def spinfirst(f, self, *args, **kwargs):
55 def spinfirst(f, self, *args, **kwargs):
55 """Call spin() to sync state prior to calling the method."""
56 """Call spin() to sync state prior to calling the method."""
56 self.spin()
57 self.spin()
57 return f(self, *args, **kwargs)
58 return f(self, *args, **kwargs)
58
59
59 @decorator
60 @decorator
60 def defaultblock(f, self, *args, **kwargs):
61 def defaultblock(f, self, *args, **kwargs):
61 """Default to self.block; preserve self.block."""
62 """Default to self.block; preserve self.block."""
62 block = kwargs.get('block',None)
63 block = kwargs.get('block',None)
63 block = self.block if block is None else block
64 block = self.block if block is None else block
64 saveblock = self.block
65 saveblock = self.block
65 self.block = block
66 self.block = block
66 ret = f(self, *args, **kwargs)
67 ret = f(self, *args, **kwargs)
67 self.block = saveblock
68 self.block = saveblock
68 return ret
69 return ret
69
70
70 def remote(client, bound=False, block=None, targets=None):
71 def remote(client, bound=False, block=None, targets=None):
71 """Turn a function into a remote function.
72 """Turn a function into a remote function.
72
73
73 This method can be used for map:
74 This method can be used for map:
74
75
75 >>> @remote(client,block=True)
76 >>> @remote(client,block=True)
76 def func(a)
77 def func(a)
77 """
78 """
78 def remote_function(f):
79 def remote_function(f):
79 return RemoteFunction(client, f, bound, block, targets)
80 return RemoteFunction(client, f, bound, block, targets)
80 return remote_function
81 return remote_function
81
82
82 #--------------------------------------------------------------------------
83 #--------------------------------------------------------------------------
83 # Classes
84 # Classes
84 #--------------------------------------------------------------------------
85 #--------------------------------------------------------------------------
85
86
86 class RemoteFunction(object):
87 class RemoteFunction(object):
87 """Turn an existing function into a remote function"""
88 """Turn an existing function into a remote function"""
88
89
89 def __init__(self, client, f, bound=False, block=None, targets=None):
90 def __init__(self, client, f, bound=False, block=None, targets=None):
90 self.client = client
91 self.client = client
91 self.func = f
92 self.func = f
92 self.block=block
93 self.block=block
93 self.bound=bound
94 self.bound=bound
94 self.targets=targets
95 self.targets=targets
95
96
96 def __call__(self, *args, **kwargs):
97 def __call__(self, *args, **kwargs):
97 return self.client.apply(self.func, args=args, kwargs=kwargs,
98 return self.client.apply(self.func, args=args, kwargs=kwargs,
98 block=self.block, targets=self.targets, bound=self.bound)
99 block=self.block, targets=self.targets, bound=self.bound)
99
100
100
101
101 class AbortedTask(object):
102 class AbortedTask(object):
102 """A basic wrapper object describing an aborted task."""
103 """A basic wrapper object describing an aborted task."""
103 def __init__(self, msg_id):
104 def __init__(self, msg_id):
104 self.msg_id = msg_id
105 self.msg_id = msg_id
105
106
106 class ControllerError(Exception):
107 class ControllerError(Exception):
107 def __init__(self, etype, evalue, tb):
108 def __init__(self, etype, evalue, tb):
108 self.etype = etype
109 self.etype = etype
109 self.evalue = evalue
110 self.evalue = evalue
110 self.traceback=tb
111 self.traceback=tb
111
112
112 class Client(object):
113 class Client(object):
113 """A semi-synchronous client to the IPython ZMQ controller
114 """A semi-synchronous client to the IPython ZMQ controller
114
115
115 Parameters
116 Parameters
116 ----------
117 ----------
117
118
118 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
119 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
119 The address of the controller's registration socket.
120 The address of the controller's registration socket.
120
121 [Default: 'tcp://127.0.0.1:10101']
122 context : zmq.Context
123 Pass an existing zmq.Context instance, otherwise the client will create its own
124 username : bytes
125 set username to be passed to the Session object
126 debug : bool
127 flag for lots of message printing for debug purposes
128
129 #-------------- ssh related args ----------------
130 # These are args for configuring the ssh tunnel to be used
131 # credentials are used to forward connections over ssh to the Controller
132 # Note that the ip given in `addr` needs to be relative to sshserver
133 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
134 # and set sshserver as the same machine the Controller is on. However,
135 # the only requirement is that sshserver is able to see the Controller
136 # (i.e. is within the same trusted network).
137
138 sshserver : str
139 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
140 If keyfile or password is specified, and this is not, it will default to
141 the ip given in addr.
142 keyfile : str; path to public key file
143 This specifies a key to be used in ssh login, default None.
144 Regular default ssh keys will be used without specifying this argument.
145 password : str;
146 Your ssh password to sshserver. Note that if this is left None,
147 you will be prompted for it if passwordless key based login is unavailable.
121
148
122 Attributes
149 Attributes
123 ----------
150 ----------
124 ids : set of int engine IDs
151 ids : set of int engine IDs
125 requesting the ids attribute always synchronizes
152 requesting the ids attribute always synchronizes
126 the registration state. To request ids without synchronization,
153 the registration state. To request ids without synchronization,
127 use semi-private _ids.
154 use semi-private _ids.
128
155
129 history : list of msg_ids
156 history : list of msg_ids
130 a list of msg_ids, keeping track of all the execution
157 a list of msg_ids, keeping track of all the execution
131 messages you have submitted in order.
158 messages you have submitted in order.
132
159
133 outstanding : set of msg_ids
160 outstanding : set of msg_ids
134 a set of msg_ids that have been submitted, but whose
161 a set of msg_ids that have been submitted, but whose
135 results have not yet been received.
162 results have not yet been received.
136
163
137 results : dict
164 results : dict
138 a dict of all our results, keyed by msg_id
165 a dict of all our results, keyed by msg_id
139
166
140 block : bool
167 block : bool
141 determines default behavior when block not specified
168 determines default behavior when block not specified
142 in execution methods
169 in execution methods
143
170
144 Methods
171 Methods
145 -------
172 -------
146 spin : flushes incoming results and registration state changes
173 spin : flushes incoming results and registration state changes
147 control methods spin, and requesting `ids` also ensures up to date
174 control methods spin, and requesting `ids` also ensures up to date
148
175
149 barrier : wait on one or more msg_ids
176 barrier : wait on one or more msg_ids
150
177
151 execution methods: apply/apply_bound/apply_to/applu_bount
178 execution methods: apply/apply_bound/apply_to/applu_bount
152 legacy: execute, run
179 legacy: execute, run
153
180
154 query methods: queue_status, get_result, purge
181 query methods: queue_status, get_result, purge
155
182
156 control methods: abort, kill
183 control methods: abort, kill
157
184
158 """
185 """
159
186
160
187
161 _connected=False
188 _connected=False
189 _ssh=False
162 _engines=None
190 _engines=None
163 _addr='tcp://127.0.0.1:10101'
191 _addr='tcp://127.0.0.1:10101'
164 _registration_socket=None
192 _registration_socket=None
165 _query_socket=None
193 _query_socket=None
166 _control_socket=None
194 _control_socket=None
167 _notification_socket=None
195 _notification_socket=None
168 _mux_socket=None
196 _mux_socket=None
169 _task_socket=None
197 _task_socket=None
170 block = False
198 block = False
171 outstanding=None
199 outstanding=None
172 results = None
200 results = None
173 history = None
201 history = None
174 debug = False
202 debug = False
175
203
176 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False):
204 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
205 sshserver=None, keyfile=None, password=None, paramiko=None):
177 if context is None:
206 if context is None:
178 context = zmq.Context()
207 context = zmq.Context()
179 self.context = context
208 self.context = context
180 self._addr = addr
209 self._addr = addr
210 self._ssh = bool(sshserver or keyfile or password)
211 if self._ssh and sshserver is None:
212 # default to the same
213 sshserver = addr.split('://')[1].split(':')[0]
214 if self._ssh and password is None:
215 if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko):
216 password=False
217 else:
218 password = getpass("SSH Password for %s: "%sshserver)
219 ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko)
220
181 if username is None:
221 if username is None:
182 self.session = ss.StreamSession()
222 self.session = ss.StreamSession()
183 else:
223 else:
184 self.session = ss.StreamSession(username)
224 self.session = ss.StreamSession(username)
185 self._registration_socket = self.context.socket(zmq.PAIR)
225 self._registration_socket = self.context.socket(zmq.XREQ)
186 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
226 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
187 self._registration_socket.connect(addr)
227 if self._ssh:
228 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
229 else:
230 self._registration_socket.connect(addr)
188 self._engines = {}
231 self._engines = {}
189 self._ids = set()
232 self._ids = set()
190 self.outstanding=set()
233 self.outstanding=set()
191 self.results = {}
234 self.results = {}
192 self.history = []
235 self.history = []
193 self.debug = debug
236 self.debug = debug
194 self.session.debug = debug
237 self.session.debug = debug
195
238
196 self._notification_handlers = {'registration_notification' : self._register_engine,
239 self._notification_handlers = {'registration_notification' : self._register_engine,
197 'unregistration_notification' : self._unregister_engine,
240 'unregistration_notification' : self._unregister_engine,
198 }
241 }
199 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
242 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
200 'apply_reply' : self._handle_apply_reply}
243 'apply_reply' : self._handle_apply_reply}
201 self._connect()
244 self._connect(sshserver, ssh_kwargs)
202
245
203
246
204 @property
247 @property
205 def ids(self):
248 def ids(self):
206 """Always up to date ids property."""
249 """Always up to date ids property."""
207 self._flush_notifications()
250 self._flush_notifications()
208 return self._ids
251 return self._ids
209
252
210 def _update_engines(self, engines):
253 def _update_engines(self, engines):
211 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
254 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
212 for k,v in engines.iteritems():
255 for k,v in engines.iteritems():
213 eid = int(k)
256 eid = int(k)
214 self._engines[eid] = bytes(v) # force not unicode
257 self._engines[eid] = bytes(v) # force not unicode
215 self._ids.add(eid)
258 self._ids.add(eid)
216
259
217 def _build_targets(self, targets):
260 def _build_targets(self, targets):
218 """Turn valid target IDs or 'all' into two lists:
261 """Turn valid target IDs or 'all' into two lists:
219 (int_ids, uuids).
262 (int_ids, uuids).
220 """
263 """
221 if targets is None:
264 if targets is None:
222 targets = self._ids
265 targets = self._ids
223 elif isinstance(targets, str):
266 elif isinstance(targets, str):
224 if targets.lower() == 'all':
267 if targets.lower() == 'all':
225 targets = self._ids
268 targets = self._ids
226 else:
269 else:
227 raise TypeError("%r not valid str target, must be 'all'"%(targets))
270 raise TypeError("%r not valid str target, must be 'all'"%(targets))
228 elif isinstance(targets, int):
271 elif isinstance(targets, int):
229 targets = [targets]
272 targets = [targets]
230 return [self._engines[t] for t in targets], list(targets)
273 return [self._engines[t] for t in targets], list(targets)
231
274
232 def _connect(self):
275 def _connect(self, sshserver, ssh_kwargs):
233 """setup all our socket connections to the controller. This is called from
276 """setup all our socket connections to the controller. This is called from
234 __init__."""
277 __init__."""
235 if self._connected:
278 if self._connected:
236 return
279 return
237 self._connected=True
280 self._connected=True
281
282 def connect_socket(s, addr):
283 if self._ssh:
284 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
285 else:
286 return s.connect(addr)
287
238 self.session.send(self._registration_socket, 'connection_request')
288 self.session.send(self._registration_socket, 'connection_request')
239 idents,msg = self.session.recv(self._registration_socket,mode=0)
289 idents,msg = self.session.recv(self._registration_socket,mode=0)
240 if self.debug:
290 if self.debug:
241 pprint(msg)
291 pprint(msg)
242 msg = ss.Message(msg)
292 msg = ss.Message(msg)
243 content = msg.content
293 content = msg.content
244 if content.status == 'ok':
294 if content.status == 'ok':
245 if content.queue:
295 if content.queue:
246 self._mux_socket = self.context.socket(zmq.PAIR)
296 self._mux_socket = self.context.socket(zmq.PAIR)
247 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
297 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
248 self._mux_socket.connect(content.queue)
298 connect_socket(self._mux_socket, content.queue)
249 if content.task:
299 if content.task:
250 self._task_socket = self.context.socket(zmq.PAIR)
300 self._task_socket = self.context.socket(zmq.PAIR)
251 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
301 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
252 self._task_socket.connect(content.task)
302 connect_socket(self._task_socket, content.task)
253 if content.notification:
303 if content.notification:
254 self._notification_socket = self.context.socket(zmq.SUB)
304 self._notification_socket = self.context.socket(zmq.SUB)
255 self._notification_socket.connect(content.notification)
305 connect_socket(self._notification_socket, content.notification)
256 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
306 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
257 if content.query:
307 if content.query:
258 self._query_socket = self.context.socket(zmq.PAIR)
308 self._query_socket = self.context.socket(zmq.PAIR)
259 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
309 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
260 self._query_socket.connect(content.query)
310 connect_socket(self._query_socket, content.query)
261 if content.control:
311 if content.control:
262 self._control_socket = self.context.socket(zmq.PAIR)
312 self._control_socket = self.context.socket(zmq.PAIR)
263 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
313 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
264 self._control_socket.connect(content.control)
314 connect_socket(self._control_socket, content.control)
265 self._update_engines(dict(content.engines))
315 self._update_engines(dict(content.engines))
266
316
267 else:
317 else:
268 self._connected = False
318 self._connected = False
269 raise Exception("Failed to connect!")
319 raise Exception("Failed to connect!")
270
320
271 #--------------------------------------------------------------------------
321 #--------------------------------------------------------------------------
272 # handlers and callbacks for incoming messages
322 # handlers and callbacks for incoming messages
273 #--------------------------------------------------------------------------
323 #--------------------------------------------------------------------------
274
324
275 def _register_engine(self, msg):
325 def _register_engine(self, msg):
276 """Register a new engine, and update our connection info."""
326 """Register a new engine, and update our connection info."""
277 content = msg['content']
327 content = msg['content']
278 eid = content['id']
328 eid = content['id']
279 d = {eid : content['queue']}
329 d = {eid : content['queue']}
280 self._update_engines(d)
330 self._update_engines(d)
281 self._ids.add(int(eid))
331 self._ids.add(int(eid))
282
332
283 def _unregister_engine(self, msg):
333 def _unregister_engine(self, msg):
284 """Unregister an engine that has died."""
334 """Unregister an engine that has died."""
285 content = msg['content']
335 content = msg['content']
286 eid = int(content['id'])
336 eid = int(content['id'])
287 if eid in self._ids:
337 if eid in self._ids:
288 self._ids.remove(eid)
338 self._ids.remove(eid)
289 self._engines.pop(eid)
339 self._engines.pop(eid)
290
340
291 def _handle_execute_reply(self, msg):
341 def _handle_execute_reply(self, msg):
292 """Save the reply to an execute_request into our results."""
342 """Save the reply to an execute_request into our results."""
293 parent = msg['parent_header']
343 parent = msg['parent_header']
294 msg_id = parent['msg_id']
344 msg_id = parent['msg_id']
295 if msg_id not in self.outstanding:
345 if msg_id not in self.outstanding:
296 print("got unknown result: %s"%msg_id)
346 print("got unknown result: %s"%msg_id)
297 else:
347 else:
298 self.outstanding.remove(msg_id)
348 self.outstanding.remove(msg_id)
299 self.results[msg_id] = ss.unwrap_exception(msg['content'])
349 self.results[msg_id] = ss.unwrap_exception(msg['content'])
300
350
301 def _handle_apply_reply(self, msg):
351 def _handle_apply_reply(self, msg):
302 """Save the reply to an apply_request into our results."""
352 """Save the reply to an apply_request into our results."""
303 parent = msg['parent_header']
353 parent = msg['parent_header']
304 msg_id = parent['msg_id']
354 msg_id = parent['msg_id']
305 if msg_id not in self.outstanding:
355 if msg_id not in self.outstanding:
306 print ("got unknown result: %s"%msg_id)
356 print ("got unknown result: %s"%msg_id)
307 else:
357 else:
308 self.outstanding.remove(msg_id)
358 self.outstanding.remove(msg_id)
309 content = msg['content']
359 content = msg['content']
310 if content['status'] == 'ok':
360 if content['status'] == 'ok':
311 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
361 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
312 elif content['status'] == 'aborted':
362 elif content['status'] == 'aborted':
313 self.results[msg_id] = AbortedTask(msg_id)
363 self.results[msg_id] = AbortedTask(msg_id)
314 elif content['status'] == 'resubmitted':
364 elif content['status'] == 'resubmitted':
315 # TODO: handle resubmission
365 # TODO: handle resubmission
316 pass
366 pass
317 else:
367 else:
318 self.results[msg_id] = ss.unwrap_exception(content)
368 self.results[msg_id] = ss.unwrap_exception(content)
319
369
320 def _flush_notifications(self):
370 def _flush_notifications(self):
321 """Flush notifications of engine registrations waiting
371 """Flush notifications of engine registrations waiting
322 in ZMQ queue."""
372 in ZMQ queue."""
323 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
373 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
324 while msg is not None:
374 while msg is not None:
325 if self.debug:
375 if self.debug:
326 pprint(msg)
376 pprint(msg)
327 msg = msg[-1]
377 msg = msg[-1]
328 msg_type = msg['msg_type']
378 msg_type = msg['msg_type']
329 handler = self._notification_handlers.get(msg_type, None)
379 handler = self._notification_handlers.get(msg_type, None)
330 if handler is None:
380 if handler is None:
331 raise Exception("Unhandled message type: %s"%msg.msg_type)
381 raise Exception("Unhandled message type: %s"%msg.msg_type)
332 else:
382 else:
333 handler(msg)
383 handler(msg)
334 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
384 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
335
385
336 def _flush_results(self, sock):
386 def _flush_results(self, sock):
337 """Flush task or queue results waiting in ZMQ queue."""
387 """Flush task or queue results waiting in ZMQ queue."""
338 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
388 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
339 while msg is not None:
389 while msg is not None:
340 if self.debug:
390 if self.debug:
341 pprint(msg)
391 pprint(msg)
342 msg = msg[-1]
392 msg = msg[-1]
343 msg_type = msg['msg_type']
393 msg_type = msg['msg_type']
344 handler = self._queue_handlers.get(msg_type, None)
394 handler = self._queue_handlers.get(msg_type, None)
345 if handler is None:
395 if handler is None:
346 raise Exception("Unhandled message type: %s"%msg.msg_type)
396 raise Exception("Unhandled message type: %s"%msg.msg_type)
347 else:
397 else:
348 handler(msg)
398 handler(msg)
349 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
399 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
350
400
351 def _flush_control(self, sock):
401 def _flush_control(self, sock):
352 """Flush replies from the control channel waiting
402 """Flush replies from the control channel waiting
353 in the ZMQ queue.
403 in the ZMQ queue.
354
404
355 Currently: ignore them."""
405 Currently: ignore them."""
356 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
406 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
357 while msg is not None:
407 while msg is not None:
358 if self.debug:
408 if self.debug:
359 pprint(msg)
409 pprint(msg)
360 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
410 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
361
411
362 #--------------------------------------------------------------------------
412 #--------------------------------------------------------------------------
363 # getitem
413 # getitem
364 #--------------------------------------------------------------------------
414 #--------------------------------------------------------------------------
365
415
366 def __getitem__(self, key):
416 def __getitem__(self, key):
367 """Dict access returns DirectView multiplexer objects or,
417 """Dict access returns DirectView multiplexer objects or,
368 if key is None, a LoadBalancedView."""
418 if key is None, a LoadBalancedView."""
369 if key is None:
419 if key is None:
370 return LoadBalancedView(self)
420 return LoadBalancedView(self)
371 if isinstance(key, int):
421 if isinstance(key, int):
372 if key not in self.ids:
422 if key not in self.ids:
373 raise IndexError("No such engine: %i"%key)
423 raise IndexError("No such engine: %i"%key)
374 return DirectView(self, key)
424 return DirectView(self, key)
375
425
376 if isinstance(key, slice):
426 if isinstance(key, slice):
377 indices = range(len(self.ids))[key]
427 indices = range(len(self.ids))[key]
378 ids = sorted(self._ids)
428 ids = sorted(self._ids)
379 key = [ ids[i] for i in indices ]
429 key = [ ids[i] for i in indices ]
380 # newkeys = sorted(self._ids)[thekeys[k]]
430 # newkeys = sorted(self._ids)[thekeys[k]]
381
431
382 if isinstance(key, (tuple, list, xrange)):
432 if isinstance(key, (tuple, list, xrange)):
383 _,targets = self._build_targets(list(key))
433 _,targets = self._build_targets(list(key))
384 return DirectView(self, targets)
434 return DirectView(self, targets)
385 else:
435 else:
386 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
436 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
387
437
388 #--------------------------------------------------------------------------
438 #--------------------------------------------------------------------------
389 # Begin public methods
439 # Begin public methods
390 #--------------------------------------------------------------------------
440 #--------------------------------------------------------------------------
391
441
392 def spin(self):
442 def spin(self):
393 """Flush any registration notifications and execution results
443 """Flush any registration notifications and execution results
394 waiting in the ZMQ queue.
444 waiting in the ZMQ queue.
395 """
445 """
396 if self._notification_socket:
446 if self._notification_socket:
397 self._flush_notifications()
447 self._flush_notifications()
398 if self._mux_socket:
448 if self._mux_socket:
399 self._flush_results(self._mux_socket)
449 self._flush_results(self._mux_socket)
400 if self._task_socket:
450 if self._task_socket:
401 self._flush_results(self._task_socket)
451 self._flush_results(self._task_socket)
402 if self._control_socket:
452 if self._control_socket:
403 self._flush_control(self._control_socket)
453 self._flush_control(self._control_socket)
404
454
405 def barrier(self, msg_ids=None, timeout=-1):
455 def barrier(self, msg_ids=None, timeout=-1):
406 """waits on one or more `msg_ids`, for up to `timeout` seconds.
456 """waits on one or more `msg_ids`, for up to `timeout` seconds.
407
457
408 Parameters
458 Parameters
409 ----------
459 ----------
410 msg_ids : int, str, or list of ints and/or strs
460 msg_ids : int, str, or list of ints and/or strs
411 ints are indices to self.history
461 ints are indices to self.history
412 strs are msg_ids
462 strs are msg_ids
413 default: wait on all outstanding messages
463 default: wait on all outstanding messages
414 timeout : float
464 timeout : float
415 a time in seconds, after which to give up.
465 a time in seconds, after which to give up.
416 default is -1, which means no timeout
466 default is -1, which means no timeout
417
467
418 Returns
468 Returns
419 -------
469 -------
420 True : when all msg_ids are done
470 True : when all msg_ids are done
421 False : timeout reached, some msg_ids still outstanding
471 False : timeout reached, some msg_ids still outstanding
422 """
472 """
423 tic = time.time()
473 tic = time.time()
424 if msg_ids is None:
474 if msg_ids is None:
425 theids = self.outstanding
475 theids = self.outstanding
426 else:
476 else:
427 if isinstance(msg_ids, (int, str)):
477 if isinstance(msg_ids, (int, str)):
428 msg_ids = [msg_ids]
478 msg_ids = [msg_ids]
429 theids = set()
479 theids = set()
430 for msg_id in msg_ids:
480 for msg_id in msg_ids:
431 if isinstance(msg_id, int):
481 if isinstance(msg_id, int):
432 msg_id = self.history[msg_id]
482 msg_id = self.history[msg_id]
433 theids.add(msg_id)
483 theids.add(msg_id)
434 self.spin()
484 self.spin()
435 while theids.intersection(self.outstanding):
485 while theids.intersection(self.outstanding):
436 if timeout >= 0 and ( time.time()-tic ) > timeout:
486 if timeout >= 0 and ( time.time()-tic ) > timeout:
437 break
487 break
438 time.sleep(1e-3)
488 time.sleep(1e-3)
439 self.spin()
489 self.spin()
440 return len(theids.intersection(self.outstanding)) == 0
490 return len(theids.intersection(self.outstanding)) == 0
441
491
442 #--------------------------------------------------------------------------
492 #--------------------------------------------------------------------------
443 # Control methods
493 # Control methods
444 #--------------------------------------------------------------------------
494 #--------------------------------------------------------------------------
445
495
446 @spinfirst
496 @spinfirst
447 @defaultblock
497 @defaultblock
448 def clear(self, targets=None, block=None):
498 def clear(self, targets=None, block=None):
449 """Clear the namespace in target(s)."""
499 """Clear the namespace in target(s)."""
450 targets = self._build_targets(targets)[0]
500 targets = self._build_targets(targets)[0]
451 for t in targets:
501 for t in targets:
452 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
502 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
453 error = False
503 error = False
454 if self.block:
504 if self.block:
455 for i in range(len(targets)):
505 for i in range(len(targets)):
456 idents,msg = self.session.recv(self._control_socket,0)
506 idents,msg = self.session.recv(self._control_socket,0)
457 if self.debug:
507 if self.debug:
458 pprint(msg)
508 pprint(msg)
459 if msg['content']['status'] != 'ok':
509 if msg['content']['status'] != 'ok':
460 error = ss.unwrap_exception(msg['content'])
510 error = ss.unwrap_exception(msg['content'])
461 if error:
511 if error:
462 return error
512 return error
463
513
464
514
465 @spinfirst
515 @spinfirst
466 @defaultblock
516 @defaultblock
467 def abort(self, msg_ids = None, targets=None, block=None):
517 def abort(self, msg_ids = None, targets=None, block=None):
468 """Abort the execution queues of target(s)."""
518 """Abort the execution queues of target(s)."""
469 targets = self._build_targets(targets)[0]
519 targets = self._build_targets(targets)[0]
470 if isinstance(msg_ids, basestring):
520 if isinstance(msg_ids, basestring):
471 msg_ids = [msg_ids]
521 msg_ids = [msg_ids]
472 content = dict(msg_ids=msg_ids)
522 content = dict(msg_ids=msg_ids)
473 for t in targets:
523 for t in targets:
474 self.session.send(self._control_socket, 'abort_request',
524 self.session.send(self._control_socket, 'abort_request',
475 content=content, ident=t)
525 content=content, ident=t)
476 error = False
526 error = False
477 if self.block:
527 if self.block:
478 for i in range(len(targets)):
528 for i in range(len(targets)):
479 idents,msg = self.session.recv(self._control_socket,0)
529 idents,msg = self.session.recv(self._control_socket,0)
480 if self.debug:
530 if self.debug:
481 pprint(msg)
531 pprint(msg)
482 if msg['content']['status'] != 'ok':
532 if msg['content']['status'] != 'ok':
483 error = ss.unwrap_exception(msg['content'])
533 error = ss.unwrap_exception(msg['content'])
484 if error:
534 if error:
485 return error
535 return error
486
536
487 @spinfirst
537 @spinfirst
488 @defaultblock
538 @defaultblock
489 def kill(self, targets=None, block=None):
539 def kill(self, targets=None, block=None):
490 """Terminates one or more engine processes."""
540 """Terminates one or more engine processes."""
491 targets = self._build_targets(targets)[0]
541 targets = self._build_targets(targets)[0]
492 for t in targets:
542 for t in targets:
493 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
543 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
494 error = False
544 error = False
495 if self.block:
545 if self.block:
496 for i in range(len(targets)):
546 for i in range(len(targets)):
497 idents,msg = self.session.recv(self._control_socket,0)
547 idents,msg = self.session.recv(self._control_socket,0)
498 if self.debug:
548 if self.debug:
499 pprint(msg)
549 pprint(msg)
500 if msg['content']['status'] != 'ok':
550 if msg['content']['status'] != 'ok':
501 error = ss.unwrap_exception(msg['content'])
551 error = ss.unwrap_exception(msg['content'])
502 if error:
552 if error:
503 return error
553 return error
504
554
505 #--------------------------------------------------------------------------
555 #--------------------------------------------------------------------------
506 # Execution methods
556 # Execution methods
507 #--------------------------------------------------------------------------
557 #--------------------------------------------------------------------------
508
558
509 @defaultblock
559 @defaultblock
510 def execute(self, code, targets='all', block=None):
560 def execute(self, code, targets='all', block=None):
511 """Executes `code` on `targets` in blocking or nonblocking manner.
561 """Executes `code` on `targets` in blocking or nonblocking manner.
512
562
513 Parameters
563 Parameters
514 ----------
564 ----------
515 code : str
565 code : str
516 the code string to be executed
566 the code string to be executed
517 targets : int/str/list of ints/strs
567 targets : int/str/list of ints/strs
518 the engines on which to execute
568 the engines on which to execute
519 default : all
569 default : all
520 block : bool
570 block : bool
521 whether or not to wait until done to return
571 whether or not to wait until done to return
522 default: self.block
572 default: self.block
523 """
573 """
524 # block = self.block if block is None else block
574 # block = self.block if block is None else block
525 # saveblock = self.block
575 # saveblock = self.block
526 # self.block = block
576 # self.block = block
527 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
577 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
528 # self.block = saveblock
578 # self.block = saveblock
529 return result
579 return result
530
580
531 def run(self, code, block=None):
581 def run(self, code, block=None):
532 """Runs `code` on an engine.
582 """Runs `code` on an engine.
533
583
534 Calls to this are load-balanced.
584 Calls to this are load-balanced.
535
585
536 Parameters
586 Parameters
537 ----------
587 ----------
538 code : str
588 code : str
539 the code string to be executed
589 the code string to be executed
540 block : bool
590 block : bool
541 whether or not to wait until done
591 whether or not to wait until done
542
592
543 """
593 """
544 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
594 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
545 return result
595 return result
546
596
547 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
597 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
548 after=None, follow=None):
598 after=None, follow=None):
549 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
599 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
550
600
551 This is the central execution command for the client.
601 This is the central execution command for the client.
552
602
553 Parameters
603 Parameters
554 ----------
604 ----------
555
605
556 f : function
606 f : function
557 The fuction to be called remotely
607 The fuction to be called remotely
558 args : tuple/list
608 args : tuple/list
559 The positional arguments passed to `f`
609 The positional arguments passed to `f`
560 kwargs : dict
610 kwargs : dict
561 The keyword arguments passed to `f`
611 The keyword arguments passed to `f`
562 bound : bool (default: True)
612 bound : bool (default: True)
563 Whether to execute in the Engine(s) namespace, or in a clean
613 Whether to execute in the Engine(s) namespace, or in a clean
564 namespace not affecting the engine.
614 namespace not affecting the engine.
565 block : bool (default: self.block)
615 block : bool (default: self.block)
566 Whether to wait for the result, or return immediately.
616 Whether to wait for the result, or return immediately.
567 False:
617 False:
568 returns msg_id(s)
618 returns msg_id(s)
569 if multiple targets:
619 if multiple targets:
570 list of ids
620 list of ids
571 True:
621 True:
572 returns actual result(s) of f(*args, **kwargs)
622 returns actual result(s) of f(*args, **kwargs)
573 if multiple targets:
623 if multiple targets:
574 dict of results, by engine ID
624 dict of results, by engine ID
575 targets : int,list of ints, 'all', None
625 targets : int,list of ints, 'all', None
576 Specify the destination of the job.
626 Specify the destination of the job.
577 if None:
627 if None:
578 Submit via Task queue for load-balancing.
628 Submit via Task queue for load-balancing.
579 if 'all':
629 if 'all':
580 Run on all active engines
630 Run on all active engines
581 if list:
631 if list:
582 Run on each specified engine
632 Run on each specified engine
583 if int:
633 if int:
584 Run on single engine
634 Run on single engine
585
635
586 after : Dependency or collection of msg_ids
636 after : Dependency or collection of msg_ids
587 Only for load-balanced execution (targets=None)
637 Only for load-balanced execution (targets=None)
588 Specify a list of msg_ids as a time-based dependency.
638 Specify a list of msg_ids as a time-based dependency.
589 This job will only be run *after* the dependencies
639 This job will only be run *after* the dependencies
590 have been met.
640 have been met.
591
641
592 follow : Dependency or collection of msg_ids
642 follow : Dependency or collection of msg_ids
593 Only for load-balanced execution (targets=None)
643 Only for load-balanced execution (targets=None)
594 Specify a list of msg_ids as a location-based dependency.
644 Specify a list of msg_ids as a location-based dependency.
595 This job will only be run on an engine where this dependency
645 This job will only be run on an engine where this dependency
596 is met.
646 is met.
597
647
598 Returns
648 Returns
599 -------
649 -------
600 if block is False:
650 if block is False:
601 if single target:
651 if single target:
602 return msg_id
652 return msg_id
603 else:
653 else:
604 return list of msg_ids
654 return list of msg_ids
605 ? (should this be dict like block=True) ?
655 ? (should this be dict like block=True) ?
606 else:
656 else:
607 if single target:
657 if single target:
608 return result of f(*args, **kwargs)
658 return result of f(*args, **kwargs)
609 else:
659 else:
610 return dict of results, keyed by engine
660 return dict of results, keyed by engine
611 """
661 """
612
662
613 # defaults:
663 # defaults:
614 block = block if block is not None else self.block
664 block = block if block is not None else self.block
615 args = args if args is not None else []
665 args = args if args is not None else []
616 kwargs = kwargs if kwargs is not None else {}
666 kwargs = kwargs if kwargs is not None else {}
617
667
618 # enforce types of f,args,kwrags
668 # enforce types of f,args,kwrags
619 if not callable(f):
669 if not callable(f):
620 raise TypeError("f must be callable, not %s"%type(f))
670 raise TypeError("f must be callable, not %s"%type(f))
621 if not isinstance(args, (tuple, list)):
671 if not isinstance(args, (tuple, list)):
622 raise TypeError("args must be tuple or list, not %s"%type(args))
672 raise TypeError("args must be tuple or list, not %s"%type(args))
623 if not isinstance(kwargs, dict):
673 if not isinstance(kwargs, dict):
624 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
674 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
625
675
626 options = dict(bound=bound, block=block, after=after, follow=follow)
676 options = dict(bound=bound, block=block, after=after, follow=follow)
627
677
628 if targets is None:
678 if targets is None:
629 return self._apply_balanced(f, args, kwargs, **options)
679 return self._apply_balanced(f, args, kwargs, **options)
630 else:
680 else:
631 return self._apply_direct(f, args, kwargs, targets=targets, **options)
681 return self._apply_direct(f, args, kwargs, targets=targets, **options)
632
682
633 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
683 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
634 after=None, follow=None):
684 after=None, follow=None):
635 """The underlying method for applying functions in a load balanced
685 """The underlying method for applying functions in a load balanced
636 manner, via the task queue."""
686 manner, via the task queue."""
637 if isinstance(after, Dependency):
687 if isinstance(after, Dependency):
638 after = after.as_dict()
688 after = after.as_dict()
639 elif after is None:
689 elif after is None:
640 after = []
690 after = []
641 if isinstance(follow, Dependency):
691 if isinstance(follow, Dependency):
642 follow = follow.as_dict()
692 follow = follow.as_dict()
643 elif follow is None:
693 elif follow is None:
644 follow = []
694 follow = []
645 subheader = dict(after=after, follow=follow)
695 subheader = dict(after=after, follow=follow)
646
696
647 bufs = ss.pack_apply_message(f,args,kwargs)
697 bufs = ss.pack_apply_message(f,args,kwargs)
648 content = dict(bound=bound)
698 content = dict(bound=bound)
649 msg = self.session.send(self._task_socket, "apply_request",
699 msg = self.session.send(self._task_socket, "apply_request",
650 content=content, buffers=bufs, subheader=subheader)
700 content=content, buffers=bufs, subheader=subheader)
651 msg_id = msg['msg_id']
701 msg_id = msg['msg_id']
652 self.outstanding.add(msg_id)
702 self.outstanding.add(msg_id)
653 self.history.append(msg_id)
703 self.history.append(msg_id)
654 if block:
704 if block:
655 self.barrier(msg_id)
705 self.barrier(msg_id)
656 return self.results[msg_id]
706 return self.results[msg_id]
657 else:
707 else:
658 return msg_id
708 return msg_id
659
709
660 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
710 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
661 after=None, follow=None):
711 after=None, follow=None):
662 """Then underlying method for applying functions to specific engines
712 """Then underlying method for applying functions to specific engines
663 via the MUX queue."""
713 via the MUX queue."""
664
714
665 queues,targets = self._build_targets(targets)
715 queues,targets = self._build_targets(targets)
666 bufs = ss.pack_apply_message(f,args,kwargs)
716 bufs = ss.pack_apply_message(f,args,kwargs)
667 if isinstance(after, Dependency):
717 if isinstance(after, Dependency):
668 after = after.as_dict()
718 after = after.as_dict()
669 elif after is None:
719 elif after is None:
670 after = []
720 after = []
671 if isinstance(follow, Dependency):
721 if isinstance(follow, Dependency):
672 follow = follow.as_dict()
722 follow = follow.as_dict()
673 elif follow is None:
723 elif follow is None:
674 follow = []
724 follow = []
675 subheader = dict(after=after, follow=follow)
725 subheader = dict(after=after, follow=follow)
676 content = dict(bound=bound)
726 content = dict(bound=bound)
677 msg_ids = []
727 msg_ids = []
678 for queue in queues:
728 for queue in queues:
679 msg = self.session.send(self._mux_socket, "apply_request",
729 msg = self.session.send(self._mux_socket, "apply_request",
680 content=content, buffers=bufs,ident=queue, subheader=subheader)
730 content=content, buffers=bufs,ident=queue, subheader=subheader)
681 msg_id = msg['msg_id']
731 msg_id = msg['msg_id']
682 self.outstanding.add(msg_id)
732 self.outstanding.add(msg_id)
683 self.history.append(msg_id)
733 self.history.append(msg_id)
684 msg_ids.append(msg_id)
734 msg_ids.append(msg_id)
685 if block:
735 if block:
686 self.barrier(msg_ids)
736 self.barrier(msg_ids)
687 else:
737 else:
688 if len(msg_ids) == 1:
738 if len(msg_ids) == 1:
689 return msg_ids[0]
739 return msg_ids[0]
690 else:
740 else:
691 return msg_ids
741 return msg_ids
692 if len(msg_ids) == 1:
742 if len(msg_ids) == 1:
693 return self.results[msg_ids[0]]
743 return self.results[msg_ids[0]]
694 else:
744 else:
695 result = {}
745 result = {}
696 for target,mid in zip(targets, msg_ids):
746 for target,mid in zip(targets, msg_ids):
697 result[target] = self.results[mid]
747 result[target] = self.results[mid]
698 return result
748 return result
699
749
700 #--------------------------------------------------------------------------
750 #--------------------------------------------------------------------------
701 # Data movement
751 # Data movement
702 #--------------------------------------------------------------------------
752 #--------------------------------------------------------------------------
703
753
704 @defaultblock
754 @defaultblock
705 def push(self, ns, targets=None, block=None):
755 def push(self, ns, targets=None, block=None):
706 """Push the contents of `ns` into the namespace on `target`"""
756 """Push the contents of `ns` into the namespace on `target`"""
707 if not isinstance(ns, dict):
757 if not isinstance(ns, dict):
708 raise TypeError("Must be a dict, not %s"%type(ns))
758 raise TypeError("Must be a dict, not %s"%type(ns))
709 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
759 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
710 return result
760 return result
711
761
712 @defaultblock
762 @defaultblock
713 def pull(self, keys, targets=None, block=True):
763 def pull(self, keys, targets=None, block=True):
714 """Pull objects from `target`'s namespace by `keys`"""
764 """Pull objects from `target`'s namespace by `keys`"""
715 if isinstance(keys, str):
765 if isinstance(keys, str):
716 pass
766 pass
717 elif isistance(keys, (list,tuple,set)):
767 elif isistance(keys, (list,tuple,set)):
718 for key in keys:
768 for key in keys:
719 if not isinstance(key, str):
769 if not isinstance(key, str):
720 raise TypeError
770 raise TypeError
721 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
771 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
722 return result
772 return result
723
773
724 #--------------------------------------------------------------------------
774 #--------------------------------------------------------------------------
725 # Query methods
775 # Query methods
726 #--------------------------------------------------------------------------
776 #--------------------------------------------------------------------------
727
777
728 @spinfirst
778 @spinfirst
729 def get_results(self, msg_ids, status_only=False):
779 def get_results(self, msg_ids, status_only=False):
730 """Returns the result of the execute or task request with `msg_ids`.
780 """Returns the result of the execute or task request with `msg_ids`.
731
781
732 Parameters
782 Parameters
733 ----------
783 ----------
734 msg_ids : list of ints or msg_ids
784 msg_ids : list of ints or msg_ids
735 if int:
785 if int:
736 Passed as index to self.history for convenience.
786 Passed as index to self.history for convenience.
737 status_only : bool (default: False)
787 status_only : bool (default: False)
738 if False:
788 if False:
739 return the actual results
789 return the actual results
740 """
790 """
741 if not isinstance(msg_ids, (list,tuple)):
791 if not isinstance(msg_ids, (list,tuple)):
742 msg_ids = [msg_ids]
792 msg_ids = [msg_ids]
743 theids = []
793 theids = []
744 for msg_id in msg_ids:
794 for msg_id in msg_ids:
745 if isinstance(msg_id, int):
795 if isinstance(msg_id, int):
746 msg_id = self.history[msg_id]
796 msg_id = self.history[msg_id]
747 if not isinstance(msg_id, str):
797 if not isinstance(msg_id, str):
748 raise TypeError("msg_ids must be str, not %r"%msg_id)
798 raise TypeError("msg_ids must be str, not %r"%msg_id)
749 theids.append(msg_id)
799 theids.append(msg_id)
750
800
751 completed = []
801 completed = []
752 local_results = {}
802 local_results = {}
753 for msg_id in list(theids):
803 for msg_id in list(theids):
754 if msg_id in self.results:
804 if msg_id in self.results:
755 completed.append(msg_id)
805 completed.append(msg_id)
756 local_results[msg_id] = self.results[msg_id]
806 local_results[msg_id] = self.results[msg_id]
757 theids.remove(msg_id)
807 theids.remove(msg_id)
758
808
759 if theids: # some not locally cached
809 if theids: # some not locally cached
760 content = dict(msg_ids=theids, status_only=status_only)
810 content = dict(msg_ids=theids, status_only=status_only)
761 msg = self.session.send(self._query_socket, "result_request", content=content)
811 msg = self.session.send(self._query_socket, "result_request", content=content)
762 zmq.select([self._query_socket], [], [])
812 zmq.select([self._query_socket], [], [])
763 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
813 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
764 if self.debug:
814 if self.debug:
765 pprint(msg)
815 pprint(msg)
766 content = msg['content']
816 content = msg['content']
767 if content['status'] != 'ok':
817 if content['status'] != 'ok':
768 raise ss.unwrap_exception(content)
818 raise ss.unwrap_exception(content)
769 else:
819 else:
770 content = dict(completed=[],pending=[])
820 content = dict(completed=[],pending=[])
771 if not status_only:
821 if not status_only:
772 # load cached results into result:
822 # load cached results into result:
773 content['completed'].extend(completed)
823 content['completed'].extend(completed)
774 content.update(local_results)
824 content.update(local_results)
775 # update cache with results:
825 # update cache with results:
776 for msg_id in msg_ids:
826 for msg_id in msg_ids:
777 if msg_id in content['completed']:
827 if msg_id in content['completed']:
778 self.results[msg_id] = content[msg_id]
828 self.results[msg_id] = content[msg_id]
779 return content
829 return content
780
830
781 @spinfirst
831 @spinfirst
782 def queue_status(self, targets=None, verbose=False):
832 def queue_status(self, targets=None, verbose=False):
783 """Fetch the status of engine queues.
833 """Fetch the status of engine queues.
784
834
785 Parameters
835 Parameters
786 ----------
836 ----------
787 targets : int/str/list of ints/strs
837 targets : int/str/list of ints/strs
788 the engines on which to execute
838 the engines on which to execute
789 default : all
839 default : all
790 verbose : bool
840 verbose : bool
791 whether to return lengths only, or lists of ids for each element
841 whether to return lengths only, or lists of ids for each element
792 """
842 """
793 targets = self._build_targets(targets)[1]
843 targets = self._build_targets(targets)[1]
794 content = dict(targets=targets, verbose=verbose)
844 content = dict(targets=targets, verbose=verbose)
795 self.session.send(self._query_socket, "queue_request", content=content)
845 self.session.send(self._query_socket, "queue_request", content=content)
796 idents,msg = self.session.recv(self._query_socket, 0)
846 idents,msg = self.session.recv(self._query_socket, 0)
797 if self.debug:
847 if self.debug:
798 pprint(msg)
848 pprint(msg)
799 content = msg['content']
849 content = msg['content']
800 status = content.pop('status')
850 status = content.pop('status')
801 if status != 'ok':
851 if status != 'ok':
802 raise ss.unwrap_exception(content)
852 raise ss.unwrap_exception(content)
803 return content
853 return content
804
854
805 @spinfirst
855 @spinfirst
806 def purge_results(self, msg_ids=[], targets=[]):
856 def purge_results(self, msg_ids=[], targets=[]):
807 """Tell the controller to forget results.
857 """Tell the controller to forget results.
808
858
809 Individual results can be purged by msg_id, or the entire
859 Individual results can be purged by msg_id, or the entire
810 history of specific targets can
860 history of specific targets can
811
861
812 Parameters
862 Parameters
813 ----------
863 ----------
814 targets : int/str/list of ints/strs
864 targets : int/str/list of ints/strs
815 the targets
865 the targets
816 default : None
866 default : None
817 """
867 """
818 if not targets and not msg_ids:
868 if not targets and not msg_ids:
819 raise ValueError
869 raise ValueError
820 if targets:
870 if targets:
821 targets = self._build_targets(targets)[1]
871 targets = self._build_targets(targets)[1]
822 content = dict(targets=targets, msg_ids=msg_ids)
872 content = dict(targets=targets, msg_ids=msg_ids)
823 self.session.send(self._query_socket, "purge_request", content=content)
873 self.session.send(self._query_socket, "purge_request", content=content)
824 idents, msg = self.session.recv(self._query_socket, 0)
874 idents, msg = self.session.recv(self._query_socket, 0)
825 if self.debug:
875 if self.debug:
826 pprint(msg)
876 pprint(msg)
827 content = msg['content']
877 content = msg['content']
828 if content['status'] != 'ok':
878 if content['status'] != 'ok':
829 raise ss.unwrap_exception(content)
879 raise ss.unwrap_exception(content)
830
880
831 class AsynClient(Client):
881 class AsynClient(Client):
832 """An Asynchronous client, using the Tornado Event Loop.
882 """An Asynchronous client, using the Tornado Event Loop.
833 !!!unfinished!!!"""
883 !!!unfinished!!!"""
834 io_loop = None
884 io_loop = None
835 _queue_stream = None
885 _queue_stream = None
836 _notifier_stream = None
886 _notifier_stream = None
837 _task_stream = None
887 _task_stream = None
838 _control_stream = None
888 _control_stream = None
839
889
840 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
890 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
841 Client.__init__(self, addr, context, username, debug)
891 Client.__init__(self, addr, context, username, debug)
842 if io_loop is None:
892 if io_loop is None:
843 io_loop = ioloop.IOLoop.instance()
893 io_loop = ioloop.IOLoop.instance()
844 self.io_loop = io_loop
894 self.io_loop = io_loop
845
895
846 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
896 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
847 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
897 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
848 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
898 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
849 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
899 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
850
900
851 def spin(self):
901 def spin(self):
852 for stream in (self.queue_stream, self.notifier_stream,
902 for stream in (self.queue_stream, self.notifier_stream,
853 self.task_stream, self.control_stream):
903 self.task_stream, self.control_stream):
854 stream.flush()
904 stream.flush()
855 No newline at end of file
905
@@ -1,123 +1,300
1 """Basic ssh tunneling utilities."""
1
2
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2008-2010 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-----------------------------------------------------------------------------
2
9
3 #-----------------------------------------
10
11
12 #-----------------------------------------------------------------------------
4 # Imports
13 # Imports
5 #-----------------------------------------
14 #-----------------------------------------------------------------------------
6
15
7 from __future__ import print_function
16 from __future__ import print_function
8
17
9 import os,sys
18 import os,sys, atexit
10 from multiprocessing import Process
19 from multiprocessing import Process
11 from getpass import getpass, getuser
20 from getpass import getpass, getuser
12
21
13 try:
22 try:
14 import paramiko
23 import paramiko
15 except ImportError:
24 except ImportError:
16 paramiko = None
25 paramiko = None
17 else:
26 else:
18 from forward import forward_tunnel
27 from forward import forward_tunnel
28
29 try:
30 from IPython.external import pexpect
31 except ImportError:
32 pexpect = None
33
34 from IPython.zmq.parallel.entry_point import select_random_ports
35
36 #-----------------------------------------------------------------------------
37 # Code
38 #-----------------------------------------------------------------------------
39
40 #-----------------------------------------------------------------------------
41 # Check for passwordless login
42 #-----------------------------------------------------------------------------
43
44 def try_passwordless_ssh(server, keyfile, paramiko=None):
45 """Attempt to make an ssh connection without a password.
46 This is mainly used for requiring password input only once
47 when many tunnels may be connected to the same server.
19
48
20 from IPython.external import pexpect
49 If paramiko is None, the default for the platform is chosen.
50 """
51 if paramiko is None:
52 paramiko = sys.platform == 'win32'
53 if not paramiko:
54 f = _try_passwordless_openssh
55 else:
56 f = _try_passwordless_paramiko
57 return f(server, keyfile)
21
58
59 def _try_passwordless_openssh(server, keyfile):
60 """Try passwordless login with shell ssh command."""
61 if pexpect is None:
62 raise ImportError("pexpect unavailable, use paramiko")
63 cmd = 'ssh -f '+ server
64 if keyfile:
65 cmd += ' -i ' + keyfile
66 cmd += ' exit'
67 p = pexpect.spawn(cmd)
68 while True:
69 try:
70 p.expect('[Ppassword]:', timeout=.1)
71 except pexpect.TIMEOUT:
72 continue
73 except pexpect.EOF:
74 return True
75 else:
76 return False
22
77
23 def launch_ssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, timeout=15):
78 def _try_passwordless_paramiko(server, keyfile):
79 """Try passwordless login with paramiko."""
80 if paramiko is None:
81 raise ImportError("paramiko unavailable, use openssh")
82 username, server, port = _split_server(server)
83 client = paramiko.SSHClient()
84 client.load_system_host_keys()
85 client.set_missing_host_key_policy(paramiko.WarningPolicy())
86 try:
87 client.connect(server, port, username=username, key_filename=keyfile,
88 look_for_keys=True)
89 except paramiko.AuthenticationException:
90 return False
91 else:
92 client.close()
93 return True
94
95
96 def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None):
97 """Connect a socket to an address via an ssh tunnel.
98
99 This is a wrapper for socket.connect(addr), when addr is not accessible
100 from the local machine. It simply creates an ssh tunnel using the remaining args,
101 and calls socket.connect('tcp://localhost:lport') where lport is the randomly
102 selected local port of the tunnel.
103
104 """
105 lport = select_random_ports(1)[0]
106 transport, addr = addr.split('://')
107 ip,rport = addr.split(':')
108 rport = int(rport)
109 if paramiko is None:
110 paramiko = sys.platform == 'win32'
111 if paramiko:
112 tunnelf = paramiko_tunnel
113 else:
114 tunnelf = openssh_tunnel
115 tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password)
116 socket.connect('tcp://127.0.0.1:%i'%lport)
117 return tunnel
118
119 def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15):
24 """Create an ssh tunnel using command-line ssh that connects port lport
120 """Create an ssh tunnel using command-line ssh that connects port lport
25 on this machine to localhost:rport on server. The tunnel
121 on this machine to localhost:rport on server. The tunnel
26 will automatically close when not in use, remaining open
122 will automatically close when not in use, remaining open
27 for a minimum of timeout seconds for an initial connection.
123 for a minimum of timeout seconds for an initial connection.
124
125 This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
126 as seen from `server`.
127
128 keyfile and password may be specified, but ssh config is checked for defaults.
129
130 Parameters
131 ----------
132
133 lport : int
134 local port for connecting to the tunnel from this machine.
135 rport : int
136 port on the remote machine to connect to.
137 server : str
138 The ssh server to connect to. The full ssh server string will be parsed.
139 user@server:port
140 remoteip : str [Default: 127.0.0.1]
141 The remote ip, specifying the destination of the tunnel.
142 Default is localhost, which means that the tunnel would redirect
143 localhost:lport on this machine to localhost:rport on the *server*.
144
145 keyfile : str; path to public key file
146 This specifies a key to be used in ssh login, default None.
147 Regular default ssh keys will be used without specifying this argument.
148 password : str;
149 Your ssh password to the ssh server. Note that if this is left None,
150 you will be prompted for it if passwordless key based login is unavailable.
151
28 """
152 """
153 if pexpect is None:
154 raise ImportError("pexpect unavailable, use paramiko_tunnel")
29 ssh="ssh "
155 ssh="ssh "
30 if keyfile:
156 if keyfile:
31 ssh += "-i " + keyfile
157 ssh += "-i " + keyfile
32 cmd = ssh + " -f -L %i:127.0.0.1:%i %s sleep %i"%(lport, rport, server, timeout)
158 cmd = ssh + " -f -L 127.0.0.1:%i:127.0.0.1:%i %s sleep %i"%(lport, rport, server, timeout)
33 tunnel = pexpect.spawn(cmd)
159 tunnel = pexpect.spawn(cmd)
34 failed = False
160 failed = False
35 while True:
161 while True:
36 try:
162 try:
37 tunnel.expect('[Pp]assword:', timeout=.1)
163 tunnel.expect('[Pp]assword:', timeout=.1)
38 except pexpect.TIMEOUT:
164 except pexpect.TIMEOUT:
39 continue
165 continue
40 except pexpect.EOF:
166 except pexpect.EOF:
41 if tunnel.exitstatus:
167 if tunnel.exitstatus:
42 print (tunnel.exitstatus)
168 print (tunnel.exitstatus)
43 print (tunnel.before)
169 print (tunnel.before)
44 print (tunnel.after)
170 print (tunnel.after)
45 raise RuntimeError("tunnel '%s' failed to start"%(cmd))
171 raise RuntimeError("tunnel '%s' failed to start"%(cmd))
46 else:
172 else:
47 return tunnel.pid
173 return tunnel.pid
48 else:
174 else:
49 if failed:
175 if failed:
50 print("Password rejected, try again")
176 print("Password rejected, try again")
51 tunnel.sendline(getpass())
177 password=None
178 if password is None:
179 password = getpass("%s's password: "%(server))
180 tunnel.sendline(password)
52 failed = True
181 failed = True
53
182
54 def _split_server(server):
183 def _split_server(server):
55 if '@' in server:
184 if '@' in server:
56 username,server = server.split('@', 1)
185 username,server = server.split('@', 1)
57 else:
186 else:
58 username = getuser()
187 username = getuser()
59 if ':' in server:
188 if ':' in server:
60 server, port = server.split(':')
189 server, port = server.split(':')
61 port = int(port)
190 port = int(port)
62 else:
191 else:
63 port = 22
192 port = 22
64 return username, server, port
193 return username, server, port
65
194
66 def launch_paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None):
195 def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15):
67 """launch a tunner with paramiko in a subprocess"""
196 """launch a tunner with paramiko in a subprocess. This should only be used
197 when shell ssh is unavailable (e.g. Windows).
198
199 This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
200 as seen from `server`.
201
202 keyfile and password may be specified, but ssh config is checked for defaults.
203
204 Parameters
205 ----------
206
207 lport : int
208 local port for connecting to the tunnel from this machine.
209 rport : int
210 port on the remote machine to connect to.
211 server : str
212 The ssh server to connect to. The full ssh server string will be parsed.
213 user@server:port
214 remoteip : str [Default: 127.0.0.1]
215 The remote ip, specifying the destination of the tunnel.
216 Default is localhost, which means that the tunnel would redirect
217 localhost:lport on this machine to localhost:rport on the *server*.
218
219 keyfile : str; path to public key file
220 This specifies a key to be used in ssh login, default None.
221 Regular default ssh keys will be used without specifying this argument.
222 password : str;
223 Your ssh password to the ssh server. Note that if this is left None,
224 you will be prompted for it if passwordless key based login is unavailable.
225
226 """
68 if paramiko is None:
227 if paramiko is None:
69 raise ImportError("Paramiko not available")
228 raise ImportError("Paramiko not available")
70 server = _split_server(server)
229
71 if keyfile is None:
230 if password is None:
72 passwd = getpass("%s@%s's password: "%(server[0], server[1]))
231 if not _check_passwordless_paramiko(server, keyfile):
73 else:
232 password = getpass("%s's password: "%(server))
74 passwd = None
233
75 p = Process(target=_paramiko_tunnel,
234 p = Process(target=_paramiko_tunnel,
76 args=(lport, rport, server, remoteip),
235 args=(lport, rport, server, remoteip),
77 kwargs=dict(keyfile=keyfile, password=passwd))
236 kwargs=dict(keyfile=keyfile, password=password))
78 p.daemon=False
237 p.daemon=False
79 p.start()
238 p.start()
239 atexit.register(_shutdown_process, p)
80 return p
240 return p
81
241
242 def _shutdown_process(p):
243 if p.isalive():
244 p.terminate()
82
245
83 def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
246 def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
84 """function for actually starting a paramiko tunnel, to be passed
247 """function for actually starting a paramiko tunnel, to be passed
85 to multiprocessing.Process(target=this).
248 to multiprocessing.Process(target=this).
86 """
249 """
87 username, server, port = server
250 username, server, port = _split_server(server)
88 client = paramiko.SSHClient()
251 client = paramiko.SSHClient()
89 client.load_system_host_keys()
252 client.load_system_host_keys()
90 client.set_missing_host_key_policy(paramiko.WarningPolicy())
253 client.set_missing_host_key_policy(paramiko.WarningPolicy())
91
254
92 try:
255 try:
93 client.connect(server, port, username=username, key_filename=keyfile,
256 client.connect(server, port, username=username, key_filename=keyfile,
94 look_for_keys=True, password=password)
257 look_for_keys=True, password=password)
258 # except paramiko.AuthenticationException:
259 # if password is None:
260 # password = getpass("%s@%s's password: "%(username, server))
261 # client.connect(server, port, username=username, password=password)
262 # else:
263 # raise
95 except Exception as e:
264 except Exception as e:
96 print ('*** Failed to connect to %s:%d: %r' % (server, port, e))
265 print ('*** Failed to connect to %s:%d: %r' % (server, port, e))
97 sys.exit(1)
266 sys.exit(1)
98
267
99 print ('Now forwarding port %d to %s:%d ...' % (lport, server, rport))
268 # print ('Now forwarding port %d to %s:%d ...' % (lport, server, rport))
100
269
101 try:
270 try:
102 forward_tunnel(lport, remoteip, rport, client.get_transport())
271 forward_tunnel(lport, remoteip, rport, client.get_transport())
103 except KeyboardInterrupt:
272 except KeyboardInterrupt:
104 print ('C-c: Port forwarding stopped.')
273 print ('SIGINT: Port forwarding stopped cleanly')
105 sys.exit(0)
274 sys.exit(0)
275 except Exception as e:
276 print ("Port forwarding stopped uncleanly: %s"%e)
277 sys.exit(255)
278
279 if sys.platform == 'win32':
280 ssh_tunnel = paramiko_tunnel
281 else:
282 ssh_tunnel = openssh_tunnel
106
283
107
284
108 __all__ = ['launch_ssh_tunnel', 'launch_paramiko_tunnel']
285 __all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh']
109
286
110
287
111
288
112
289
113
290
114
291
115
292
116
293
117
294
118
295
119
296
120
297
121
298
122
299
123
300
General Comments 0
You need to be logged in to leave comments. Login now