##// END OF EJS Templates
added preliminary ssh tunneling support for clients
Min RK -
Show More
@@ -1,183 +1,183
1 1 #!/usr/bin/env python
2 2
3 3 #
4 # This file is adapted from a paramiko demo, and thus LGPL 2.1.
4 # This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
5 5 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
6 6 # Edits Copyright (C) 2010 The IPython Team
7 7 #
8 8 # Paramiko is free software; you can redistribute it and/or modify it under the
9 9 # terms of the GNU Lesser General Public License as published by the Free
10 10 # Software Foundation; either version 2.1 of the License, or (at your option)
11 11 # any later version.
12 12 #
13 13 # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
14 14 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15 15 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
16 16 # details.
17 17 #
18 18 # You should have received a copy of the GNU Lesser General Public License
19 19 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
20 20 # 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
21 21
22 22 """
23 23 Sample script showing how to do local port forwarding over paramiko.
24 24
25 25 This script connects to the requested SSH server and sets up local port
26 26 forwarding (the openssh -L option) from a local port through a tunneled
27 27 connection to a destination reachable from the SSH server machine.
28 28 """
29 29
30 30 from __future__ import print_function
31 31
32 32 import getpass
33 33 import os
34 34 import socket
35 35 import select
36 36 import SocketServer
37 37 import sys
38 38 from optparse import OptionParser
39 39
40 40 import paramiko
41 41
42 42 SSH_PORT = 22
43 43 DEFAULT_PORT = 4000
44 44
45 45 g_verbose = False
46 46
47 47
48 48 class ForwardServer (SocketServer.ThreadingTCPServer):
49 49 daemon_threads = True
50 50 allow_reuse_address = True
51 51
52 52
53 53 class Handler (SocketServer.BaseRequestHandler):
54 54
55 55 def handle(self):
56 56 try:
57 57 chan = self.ssh_transport.open_channel('direct-tcpip',
58 58 (self.chain_host, self.chain_port),
59 59 self.request.getpeername())
60 60 except Exception, e:
61 61 verbose('Incoming request to %s:%d failed: %s' % (self.chain_host,
62 62 self.chain_port,
63 63 repr(e)))
64 64 return
65 65 if chan is None:
66 66 verbose('Incoming request to %s:%d was rejected by the SSH server.' %
67 67 (self.chain_host, self.chain_port))
68 68 return
69 69
70 70 verbose('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
71 71 chan.getpeername(), (self.chain_host, self.chain_port)))
72 72 while True:
73 73 r, w, x = select.select([self.request, chan], [], [])
74 74 if self.request in r:
75 75 data = self.request.recv(1024)
76 76 if len(data) == 0:
77 77 break
78 78 chan.send(data)
79 79 if chan in r:
80 80 data = chan.recv(1024)
81 81 if len(data) == 0:
82 82 break
83 83 self.request.send(data)
84 84 chan.close()
85 85 self.request.close()
86 verbose('Tunnel closed from %r' % (self.request.getpeername(),))
86 verbose('Tunnel closed ')
87 87
88 88
89 89 def forward_tunnel(local_port, remote_host, remote_port, transport):
90 90 # this is a little convoluted, but lets me configure things for the Handler
91 91 # object. (SocketServer doesn't give Handlers any way to access the outer
92 92 # server normally.)
93 93 class SubHander (Handler):
94 94 chain_host = remote_host
95 95 chain_port = remote_port
96 96 ssh_transport = transport
97 ForwardServer(('', local_port), SubHander).serve_forever()
97 ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
98 98
99 99
100 100 def verbose(s):
101 101 if g_verbose:
102 102 print (s)
103 103
104 104
105 105 HELP = """\
106 106 Set up a forward tunnel across an SSH server, using paramiko. A local port
107 107 (given with -p) is forwarded across an SSH session to an address:port from
108 108 the SSH server. This is similar to the openssh -L option.
109 109 """
110 110
111 111
112 112 def get_host_port(spec, default_port):
113 113 "parse 'hostname:22' into a host and port, with the port optional"
114 114 args = (spec.split(':', 1) + [default_port])[:2]
115 115 args[1] = int(args[1])
116 116 return args[0], args[1]
117 117
118 118
119 119 def parse_options():
120 120 global g_verbose
121 121
122 122 parser = OptionParser(usage='usage: %prog [options] <ssh-server>[:<server-port>]',
123 123 version='%prog 1.0', description=HELP)
124 124 parser.add_option('-q', '--quiet', action='store_false', dest='verbose', default=True,
125 125 help='squelch all informational output')
126 126 parser.add_option('-p', '--local-port', action='store', type='int', dest='port',
127 127 default=DEFAULT_PORT,
128 128 help='local port to forward (default: %d)' % DEFAULT_PORT)
129 129 parser.add_option('-u', '--user', action='store', type='string', dest='user',
130 130 default=getpass.getuser(),
131 131 help='username for SSH authentication (default: %s)' % getpass.getuser())
132 132 parser.add_option('-K', '--key', action='store', type='string', dest='keyfile',
133 133 default=None,
134 134 help='private key file to use for SSH authentication')
135 135 parser.add_option('', '--no-key', action='store_false', dest='look_for_keys', default=True,
136 136 help='don\'t look for or use a private key file')
137 137 parser.add_option('-P', '--password', action='store_true', dest='readpass', default=False,
138 138 help='read password (for key or password auth) from stdin')
139 139 parser.add_option('-r', '--remote', action='store', type='string', dest='remote', default=None, metavar='host:port',
140 140 help='remote host and port to forward to')
141 141 options, args = parser.parse_args()
142 142
143 143 if len(args) != 1:
144 144 parser.error('Incorrect number of arguments.')
145 145 if options.remote is None:
146 146 parser.error('Remote address required (-r).')
147 147
148 148 g_verbose = options.verbose
149 149 server_host, server_port = get_host_port(args[0], SSH_PORT)
150 150 remote_host, remote_port = get_host_port(options.remote, SSH_PORT)
151 151 return options, (server_host, server_port), (remote_host, remote_port)
152 152
153 153
154 154 def main():
155 155 options, server, remote = parse_options()
156 156
157 157 password = None
158 158 if options.readpass:
159 159 password = getpass.getpass('Enter SSH password: ')
160 160
161 161 client = paramiko.SSHClient()
162 162 client.load_system_host_keys()
163 163 client.set_missing_host_key_policy(paramiko.WarningPolicy())
164 164
165 165 verbose('Connecting to ssh host %s:%d ...' % (server[0], server[1]))
166 166 try:
167 167 client.connect(server[0], server[1], username=options.user, key_filename=options.keyfile,
168 168 look_for_keys=options.look_for_keys, password=password)
169 169 except Exception as e:
170 170 print ('*** Failed to connect to %s:%d: %r' % (server[0], server[1], e))
171 171 sys.exit(1)
172 172
173 173 verbose('Now forwarding port %d to %s:%d ...' % (options.port, remote[0], remote[1]))
174 174
175 175 try:
176 176 forward_tunnel(options.port, remote[0], remote[1], client.get_transport())
177 177 except KeyboardInterrupt:
178 178 print ('C-c: Port forwarding stopped.')
179 179 sys.exit(0)
180 180
181 181
182 182 if __name__ == '__main__':
183 183 main()
@@ -1,855 +1,905
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import time
16 16 from pprint import pprint
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 21 from IPython.external.decorator import decorator
22 from IPython.zmq import tunnel
22 23
23 24 import streamsession as ss
24 25 # from remotenamespace import RemoteNamespace
25 26 from view import DirectView, LoadBalancedView
26 27 from dependency import Dependency, depend, require
27 28
28 29 def _push(ns):
29 30 globals().update(ns)
30 31
31 32 def _pull(keys):
32 33 g = globals()
33 34 if isinstance(keys, (list,tuple, set)):
34 35 for key in keys:
35 36 if not g.has_key(key):
36 37 raise NameError("name '%s' is not defined"%key)
37 38 return map(g.get, keys)
38 39 else:
39 40 if not g.has_key(keys):
40 41 raise NameError("name '%s' is not defined"%keys)
41 42 return g.get(keys)
42 43
43 44 def _clear():
44 45 globals().clear()
45 46
46 47 def execute(code):
47 48 exec code in globals()
48 49
49 50 #--------------------------------------------------------------------------
50 51 # Decorators for Client methods
51 52 #--------------------------------------------------------------------------
52 53
53 54 @decorator
54 55 def spinfirst(f, self, *args, **kwargs):
55 56 """Call spin() to sync state prior to calling the method."""
56 57 self.spin()
57 58 return f(self, *args, **kwargs)
58 59
59 60 @decorator
60 61 def defaultblock(f, self, *args, **kwargs):
61 62 """Default to self.block; preserve self.block."""
62 63 block = kwargs.get('block',None)
63 64 block = self.block if block is None else block
64 65 saveblock = self.block
65 66 self.block = block
66 67 ret = f(self, *args, **kwargs)
67 68 self.block = saveblock
68 69 return ret
69 70
70 71 def remote(client, bound=False, block=None, targets=None):
71 72 """Turn a function into a remote function.
72 73
73 74 This method can be used for map:
74 75
75 76 >>> @remote(client,block=True)
76 77 def func(a)
77 78 """
78 79 def remote_function(f):
79 80 return RemoteFunction(client, f, bound, block, targets)
80 81 return remote_function
81 82
82 83 #--------------------------------------------------------------------------
83 84 # Classes
84 85 #--------------------------------------------------------------------------
85 86
86 87 class RemoteFunction(object):
87 88 """Turn an existing function into a remote function"""
88 89
89 90 def __init__(self, client, f, bound=False, block=None, targets=None):
90 91 self.client = client
91 92 self.func = f
92 93 self.block=block
93 94 self.bound=bound
94 95 self.targets=targets
95 96
96 97 def __call__(self, *args, **kwargs):
97 98 return self.client.apply(self.func, args=args, kwargs=kwargs,
98 99 block=self.block, targets=self.targets, bound=self.bound)
99 100
100 101
101 102 class AbortedTask(object):
102 103 """A basic wrapper object describing an aborted task."""
103 104 def __init__(self, msg_id):
104 105 self.msg_id = msg_id
105 106
106 107 class ControllerError(Exception):
107 108 def __init__(self, etype, evalue, tb):
108 109 self.etype = etype
109 110 self.evalue = evalue
110 111 self.traceback=tb
111 112
112 113 class Client(object):
113 114 """A semi-synchronous client to the IPython ZMQ controller
114 115
115 116 Parameters
116 117 ----------
117 118
118 119 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
119 120 The address of the controller's registration socket.
120
121 [Default: 'tcp://127.0.0.1:10101']
122 context : zmq.Context
123 Pass an existing zmq.Context instance, otherwise the client will create its own
124 username : bytes
125 set username to be passed to the Session object
126 debug : bool
127 flag for lots of message printing for debug purposes
128
129 #-------------- ssh related args ----------------
130 # These are args for configuring the ssh tunnel to be used
131 # credentials are used to forward connections over ssh to the Controller
132 # Note that the ip given in `addr` needs to be relative to sshserver
133 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
134 # and set sshserver as the same machine the Controller is on. However,
135 # the only requirement is that sshserver is able to see the Controller
136 # (i.e. is within the same trusted network).
137
138 sshserver : str
139 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
140 If keyfile or password is specified, and this is not, it will default to
141 the ip given in addr.
142 keyfile : str; path to public key file
143 This specifies a key to be used in ssh login, default None.
144 Regular default ssh keys will be used without specifying this argument.
145 password : str;
146 Your ssh password to sshserver. Note that if this is left None,
147 you will be prompted for it if passwordless key based login is unavailable.
121 148
122 149 Attributes
123 150 ----------
124 151 ids : set of int engine IDs
125 152 requesting the ids attribute always synchronizes
126 153 the registration state. To request ids without synchronization,
127 154 use semi-private _ids.
128 155
129 156 history : list of msg_ids
130 157 a list of msg_ids, keeping track of all the execution
131 158 messages you have submitted in order.
132 159
133 160 outstanding : set of msg_ids
134 161 a set of msg_ids that have been submitted, but whose
135 162 results have not yet been received.
136 163
137 164 results : dict
138 165 a dict of all our results, keyed by msg_id
139 166
140 167 block : bool
141 168 determines default behavior when block not specified
142 169 in execution methods
143 170
144 171 Methods
145 172 -------
146 173 spin : flushes incoming results and registration state changes
147 174 control methods spin, and requesting `ids` also ensures up to date
148 175
149 176 barrier : wait on one or more msg_ids
150 177
151 178 execution methods: apply/apply_bound/apply_to/applu_bount
152 179 legacy: execute, run
153 180
154 181 query methods: queue_status, get_result, purge
155 182
156 183 control methods: abort, kill
157 184
158 185 """
159 186
160 187
161 188 _connected=False
189 _ssh=False
162 190 _engines=None
163 191 _addr='tcp://127.0.0.1:10101'
164 192 _registration_socket=None
165 193 _query_socket=None
166 194 _control_socket=None
167 195 _notification_socket=None
168 196 _mux_socket=None
169 197 _task_socket=None
170 198 block = False
171 199 outstanding=None
172 200 results = None
173 201 history = None
174 202 debug = False
175 203
176 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False):
204 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
205 sshserver=None, keyfile=None, password=None, paramiko=None):
177 206 if context is None:
178 207 context = zmq.Context()
179 208 self.context = context
180 209 self._addr = addr
210 self._ssh = bool(sshserver or keyfile or password)
211 if self._ssh and sshserver is None:
212 # default to the same
213 sshserver = addr.split('://')[1].split(':')[0]
214 if self._ssh and password is None:
215 if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko):
216 password=False
217 else:
218 password = getpass("SSH Password for %s: "%sshserver)
219 ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko)
220
181 221 if username is None:
182 222 self.session = ss.StreamSession()
183 223 else:
184 224 self.session = ss.StreamSession(username)
185 self._registration_socket = self.context.socket(zmq.PAIR)
225 self._registration_socket = self.context.socket(zmq.XREQ)
186 226 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
227 if self._ssh:
228 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
229 else:
187 230 self._registration_socket.connect(addr)
188 231 self._engines = {}
189 232 self._ids = set()
190 233 self.outstanding=set()
191 234 self.results = {}
192 235 self.history = []
193 236 self.debug = debug
194 237 self.session.debug = debug
195 238
196 239 self._notification_handlers = {'registration_notification' : self._register_engine,
197 240 'unregistration_notification' : self._unregister_engine,
198 241 }
199 242 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
200 243 'apply_reply' : self._handle_apply_reply}
201 self._connect()
244 self._connect(sshserver, ssh_kwargs)
202 245
203 246
204 247 @property
205 248 def ids(self):
206 249 """Always up to date ids property."""
207 250 self._flush_notifications()
208 251 return self._ids
209 252
210 253 def _update_engines(self, engines):
211 254 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
212 255 for k,v in engines.iteritems():
213 256 eid = int(k)
214 257 self._engines[eid] = bytes(v) # force not unicode
215 258 self._ids.add(eid)
216 259
217 260 def _build_targets(self, targets):
218 261 """Turn valid target IDs or 'all' into two lists:
219 262 (int_ids, uuids).
220 263 """
221 264 if targets is None:
222 265 targets = self._ids
223 266 elif isinstance(targets, str):
224 267 if targets.lower() == 'all':
225 268 targets = self._ids
226 269 else:
227 270 raise TypeError("%r not valid str target, must be 'all'"%(targets))
228 271 elif isinstance(targets, int):
229 272 targets = [targets]
230 273 return [self._engines[t] for t in targets], list(targets)
231 274
232 def _connect(self):
275 def _connect(self, sshserver, ssh_kwargs):
233 276 """setup all our socket connections to the controller. This is called from
234 277 __init__."""
235 278 if self._connected:
236 279 return
237 280 self._connected=True
281
282 def connect_socket(s, addr):
283 if self._ssh:
284 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
285 else:
286 return s.connect(addr)
287
238 288 self.session.send(self._registration_socket, 'connection_request')
239 289 idents,msg = self.session.recv(self._registration_socket,mode=0)
240 290 if self.debug:
241 291 pprint(msg)
242 292 msg = ss.Message(msg)
243 293 content = msg.content
244 294 if content.status == 'ok':
245 295 if content.queue:
246 296 self._mux_socket = self.context.socket(zmq.PAIR)
247 297 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
248 self._mux_socket.connect(content.queue)
298 connect_socket(self._mux_socket, content.queue)
249 299 if content.task:
250 300 self._task_socket = self.context.socket(zmq.PAIR)
251 301 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
252 self._task_socket.connect(content.task)
302 connect_socket(self._task_socket, content.task)
253 303 if content.notification:
254 304 self._notification_socket = self.context.socket(zmq.SUB)
255 self._notification_socket.connect(content.notification)
305 connect_socket(self._notification_socket, content.notification)
256 306 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
257 307 if content.query:
258 308 self._query_socket = self.context.socket(zmq.PAIR)
259 309 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
260 self._query_socket.connect(content.query)
310 connect_socket(self._query_socket, content.query)
261 311 if content.control:
262 312 self._control_socket = self.context.socket(zmq.PAIR)
263 313 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
264 self._control_socket.connect(content.control)
314 connect_socket(self._control_socket, content.control)
265 315 self._update_engines(dict(content.engines))
266 316
267 317 else:
268 318 self._connected = False
269 319 raise Exception("Failed to connect!")
270 320
271 321 #--------------------------------------------------------------------------
272 322 # handlers and callbacks for incoming messages
273 323 #--------------------------------------------------------------------------
274 324
275 325 def _register_engine(self, msg):
276 326 """Register a new engine, and update our connection info."""
277 327 content = msg['content']
278 328 eid = content['id']
279 329 d = {eid : content['queue']}
280 330 self._update_engines(d)
281 331 self._ids.add(int(eid))
282 332
283 333 def _unregister_engine(self, msg):
284 334 """Unregister an engine that has died."""
285 335 content = msg['content']
286 336 eid = int(content['id'])
287 337 if eid in self._ids:
288 338 self._ids.remove(eid)
289 339 self._engines.pop(eid)
290 340
291 341 def _handle_execute_reply(self, msg):
292 342 """Save the reply to an execute_request into our results."""
293 343 parent = msg['parent_header']
294 344 msg_id = parent['msg_id']
295 345 if msg_id not in self.outstanding:
296 346 print("got unknown result: %s"%msg_id)
297 347 else:
298 348 self.outstanding.remove(msg_id)
299 349 self.results[msg_id] = ss.unwrap_exception(msg['content'])
300 350
301 351 def _handle_apply_reply(self, msg):
302 352 """Save the reply to an apply_request into our results."""
303 353 parent = msg['parent_header']
304 354 msg_id = parent['msg_id']
305 355 if msg_id not in self.outstanding:
306 356 print ("got unknown result: %s"%msg_id)
307 357 else:
308 358 self.outstanding.remove(msg_id)
309 359 content = msg['content']
310 360 if content['status'] == 'ok':
311 361 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
312 362 elif content['status'] == 'aborted':
313 363 self.results[msg_id] = AbortedTask(msg_id)
314 364 elif content['status'] == 'resubmitted':
315 365 # TODO: handle resubmission
316 366 pass
317 367 else:
318 368 self.results[msg_id] = ss.unwrap_exception(content)
319 369
320 370 def _flush_notifications(self):
321 371 """Flush notifications of engine registrations waiting
322 372 in ZMQ queue."""
323 373 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
324 374 while msg is not None:
325 375 if self.debug:
326 376 pprint(msg)
327 377 msg = msg[-1]
328 378 msg_type = msg['msg_type']
329 379 handler = self._notification_handlers.get(msg_type, None)
330 380 if handler is None:
331 381 raise Exception("Unhandled message type: %s"%msg.msg_type)
332 382 else:
333 383 handler(msg)
334 384 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
335 385
336 386 def _flush_results(self, sock):
337 387 """Flush task or queue results waiting in ZMQ queue."""
338 388 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
339 389 while msg is not None:
340 390 if self.debug:
341 391 pprint(msg)
342 392 msg = msg[-1]
343 393 msg_type = msg['msg_type']
344 394 handler = self._queue_handlers.get(msg_type, None)
345 395 if handler is None:
346 396 raise Exception("Unhandled message type: %s"%msg.msg_type)
347 397 else:
348 398 handler(msg)
349 399 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
350 400
351 401 def _flush_control(self, sock):
352 402 """Flush replies from the control channel waiting
353 403 in the ZMQ queue.
354 404
355 405 Currently: ignore them."""
356 406 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
357 407 while msg is not None:
358 408 if self.debug:
359 409 pprint(msg)
360 410 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
361 411
362 412 #--------------------------------------------------------------------------
363 413 # getitem
364 414 #--------------------------------------------------------------------------
365 415
366 416 def __getitem__(self, key):
367 417 """Dict access returns DirectView multiplexer objects or,
368 418 if key is None, a LoadBalancedView."""
369 419 if key is None:
370 420 return LoadBalancedView(self)
371 421 if isinstance(key, int):
372 422 if key not in self.ids:
373 423 raise IndexError("No such engine: %i"%key)
374 424 return DirectView(self, key)
375 425
376 426 if isinstance(key, slice):
377 427 indices = range(len(self.ids))[key]
378 428 ids = sorted(self._ids)
379 429 key = [ ids[i] for i in indices ]
380 430 # newkeys = sorted(self._ids)[thekeys[k]]
381 431
382 432 if isinstance(key, (tuple, list, xrange)):
383 433 _,targets = self._build_targets(list(key))
384 434 return DirectView(self, targets)
385 435 else:
386 436 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
387 437
388 438 #--------------------------------------------------------------------------
389 439 # Begin public methods
390 440 #--------------------------------------------------------------------------
391 441
392 442 def spin(self):
393 443 """Flush any registration notifications and execution results
394 444 waiting in the ZMQ queue.
395 445 """
396 446 if self._notification_socket:
397 447 self._flush_notifications()
398 448 if self._mux_socket:
399 449 self._flush_results(self._mux_socket)
400 450 if self._task_socket:
401 451 self._flush_results(self._task_socket)
402 452 if self._control_socket:
403 453 self._flush_control(self._control_socket)
404 454
405 455 def barrier(self, msg_ids=None, timeout=-1):
406 456 """waits on one or more `msg_ids`, for up to `timeout` seconds.
407 457
408 458 Parameters
409 459 ----------
410 460 msg_ids : int, str, or list of ints and/or strs
411 461 ints are indices to self.history
412 462 strs are msg_ids
413 463 default: wait on all outstanding messages
414 464 timeout : float
415 465 a time in seconds, after which to give up.
416 466 default is -1, which means no timeout
417 467
418 468 Returns
419 469 -------
420 470 True : when all msg_ids are done
421 471 False : timeout reached, some msg_ids still outstanding
422 472 """
423 473 tic = time.time()
424 474 if msg_ids is None:
425 475 theids = self.outstanding
426 476 else:
427 477 if isinstance(msg_ids, (int, str)):
428 478 msg_ids = [msg_ids]
429 479 theids = set()
430 480 for msg_id in msg_ids:
431 481 if isinstance(msg_id, int):
432 482 msg_id = self.history[msg_id]
433 483 theids.add(msg_id)
434 484 self.spin()
435 485 while theids.intersection(self.outstanding):
436 486 if timeout >= 0 and ( time.time()-tic ) > timeout:
437 487 break
438 488 time.sleep(1e-3)
439 489 self.spin()
440 490 return len(theids.intersection(self.outstanding)) == 0
441 491
442 492 #--------------------------------------------------------------------------
443 493 # Control methods
444 494 #--------------------------------------------------------------------------
445 495
446 496 @spinfirst
447 497 @defaultblock
448 498 def clear(self, targets=None, block=None):
449 499 """Clear the namespace in target(s)."""
450 500 targets = self._build_targets(targets)[0]
451 501 for t in targets:
452 502 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
453 503 error = False
454 504 if self.block:
455 505 for i in range(len(targets)):
456 506 idents,msg = self.session.recv(self._control_socket,0)
457 507 if self.debug:
458 508 pprint(msg)
459 509 if msg['content']['status'] != 'ok':
460 510 error = ss.unwrap_exception(msg['content'])
461 511 if error:
462 512 return error
463 513
464 514
465 515 @spinfirst
466 516 @defaultblock
467 517 def abort(self, msg_ids = None, targets=None, block=None):
468 518 """Abort the execution queues of target(s)."""
469 519 targets = self._build_targets(targets)[0]
470 520 if isinstance(msg_ids, basestring):
471 521 msg_ids = [msg_ids]
472 522 content = dict(msg_ids=msg_ids)
473 523 for t in targets:
474 524 self.session.send(self._control_socket, 'abort_request',
475 525 content=content, ident=t)
476 526 error = False
477 527 if self.block:
478 528 for i in range(len(targets)):
479 529 idents,msg = self.session.recv(self._control_socket,0)
480 530 if self.debug:
481 531 pprint(msg)
482 532 if msg['content']['status'] != 'ok':
483 533 error = ss.unwrap_exception(msg['content'])
484 534 if error:
485 535 return error
486 536
487 537 @spinfirst
488 538 @defaultblock
489 539 def kill(self, targets=None, block=None):
490 540 """Terminates one or more engine processes."""
491 541 targets = self._build_targets(targets)[0]
492 542 for t in targets:
493 543 self.session.send(self._control_socket, 'kill_request', content={},ident=t)
494 544 error = False
495 545 if self.block:
496 546 for i in range(len(targets)):
497 547 idents,msg = self.session.recv(self._control_socket,0)
498 548 if self.debug:
499 549 pprint(msg)
500 550 if msg['content']['status'] != 'ok':
501 551 error = ss.unwrap_exception(msg['content'])
502 552 if error:
503 553 return error
504 554
505 555 #--------------------------------------------------------------------------
506 556 # Execution methods
507 557 #--------------------------------------------------------------------------
508 558
509 559 @defaultblock
510 560 def execute(self, code, targets='all', block=None):
511 561 """Executes `code` on `targets` in blocking or nonblocking manner.
512 562
513 563 Parameters
514 564 ----------
515 565 code : str
516 566 the code string to be executed
517 567 targets : int/str/list of ints/strs
518 568 the engines on which to execute
519 569 default : all
520 570 block : bool
521 571 whether or not to wait until done to return
522 572 default: self.block
523 573 """
524 574 # block = self.block if block is None else block
525 575 # saveblock = self.block
526 576 # self.block = block
527 577 result = self.apply(execute, (code,), targets=targets, block=block, bound=True)
528 578 # self.block = saveblock
529 579 return result
530 580
531 581 def run(self, code, block=None):
532 582 """Runs `code` on an engine.
533 583
534 584 Calls to this are load-balanced.
535 585
536 586 Parameters
537 587 ----------
538 588 code : str
539 589 the code string to be executed
540 590 block : bool
541 591 whether or not to wait until done
542 592
543 593 """
544 594 result = self.apply(execute, (code,), targets=None, block=block, bound=False)
545 595 return result
546 596
547 597 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
548 598 after=None, follow=None):
549 599 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
550 600
551 601 This is the central execution command for the client.
552 602
553 603 Parameters
554 604 ----------
555 605
556 606 f : function
557 607 The fuction to be called remotely
558 608 args : tuple/list
559 609 The positional arguments passed to `f`
560 610 kwargs : dict
561 611 The keyword arguments passed to `f`
562 612 bound : bool (default: True)
563 613 Whether to execute in the Engine(s) namespace, or in a clean
564 614 namespace not affecting the engine.
565 615 block : bool (default: self.block)
566 616 Whether to wait for the result, or return immediately.
567 617 False:
568 618 returns msg_id(s)
569 619 if multiple targets:
570 620 list of ids
571 621 True:
572 622 returns actual result(s) of f(*args, **kwargs)
573 623 if multiple targets:
574 624 dict of results, by engine ID
575 625 targets : int,list of ints, 'all', None
576 626 Specify the destination of the job.
577 627 if None:
578 628 Submit via Task queue for load-balancing.
579 629 if 'all':
580 630 Run on all active engines
581 631 if list:
582 632 Run on each specified engine
583 633 if int:
584 634 Run on single engine
585 635
586 636 after : Dependency or collection of msg_ids
587 637 Only for load-balanced execution (targets=None)
588 638 Specify a list of msg_ids as a time-based dependency.
589 639 This job will only be run *after* the dependencies
590 640 have been met.
591 641
592 642 follow : Dependency or collection of msg_ids
593 643 Only for load-balanced execution (targets=None)
594 644 Specify a list of msg_ids as a location-based dependency.
595 645 This job will only be run on an engine where this dependency
596 646 is met.
597 647
598 648 Returns
599 649 -------
600 650 if block is False:
601 651 if single target:
602 652 return msg_id
603 653 else:
604 654 return list of msg_ids
605 655 ? (should this be dict like block=True) ?
606 656 else:
607 657 if single target:
608 658 return result of f(*args, **kwargs)
609 659 else:
610 660 return dict of results, keyed by engine
611 661 """
612 662
613 663 # defaults:
614 664 block = block if block is not None else self.block
615 665 args = args if args is not None else []
616 666 kwargs = kwargs if kwargs is not None else {}
617 667
618 668 # enforce types of f,args,kwrags
619 669 if not callable(f):
620 670 raise TypeError("f must be callable, not %s"%type(f))
621 671 if not isinstance(args, (tuple, list)):
622 672 raise TypeError("args must be tuple or list, not %s"%type(args))
623 673 if not isinstance(kwargs, dict):
624 674 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
625 675
626 676 options = dict(bound=bound, block=block, after=after, follow=follow)
627 677
628 678 if targets is None:
629 679 return self._apply_balanced(f, args, kwargs, **options)
630 680 else:
631 681 return self._apply_direct(f, args, kwargs, targets=targets, **options)
632 682
633 683 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
634 684 after=None, follow=None):
635 685 """The underlying method for applying functions in a load balanced
636 686 manner, via the task queue."""
637 687 if isinstance(after, Dependency):
638 688 after = after.as_dict()
639 689 elif after is None:
640 690 after = []
641 691 if isinstance(follow, Dependency):
642 692 follow = follow.as_dict()
643 693 elif follow is None:
644 694 follow = []
645 695 subheader = dict(after=after, follow=follow)
646 696
647 697 bufs = ss.pack_apply_message(f,args,kwargs)
648 698 content = dict(bound=bound)
649 699 msg = self.session.send(self._task_socket, "apply_request",
650 700 content=content, buffers=bufs, subheader=subheader)
651 701 msg_id = msg['msg_id']
652 702 self.outstanding.add(msg_id)
653 703 self.history.append(msg_id)
654 704 if block:
655 705 self.barrier(msg_id)
656 706 return self.results[msg_id]
657 707 else:
658 708 return msg_id
659 709
660 710 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
661 711 after=None, follow=None):
662 712 """Then underlying method for applying functions to specific engines
663 713 via the MUX queue."""
664 714
665 715 queues,targets = self._build_targets(targets)
666 716 bufs = ss.pack_apply_message(f,args,kwargs)
667 717 if isinstance(after, Dependency):
668 718 after = after.as_dict()
669 719 elif after is None:
670 720 after = []
671 721 if isinstance(follow, Dependency):
672 722 follow = follow.as_dict()
673 723 elif follow is None:
674 724 follow = []
675 725 subheader = dict(after=after, follow=follow)
676 726 content = dict(bound=bound)
677 727 msg_ids = []
678 728 for queue in queues:
679 729 msg = self.session.send(self._mux_socket, "apply_request",
680 730 content=content, buffers=bufs,ident=queue, subheader=subheader)
681 731 msg_id = msg['msg_id']
682 732 self.outstanding.add(msg_id)
683 733 self.history.append(msg_id)
684 734 msg_ids.append(msg_id)
685 735 if block:
686 736 self.barrier(msg_ids)
687 737 else:
688 738 if len(msg_ids) == 1:
689 739 return msg_ids[0]
690 740 else:
691 741 return msg_ids
692 742 if len(msg_ids) == 1:
693 743 return self.results[msg_ids[0]]
694 744 else:
695 745 result = {}
696 746 for target,mid in zip(targets, msg_ids):
697 747 result[target] = self.results[mid]
698 748 return result
699 749
700 750 #--------------------------------------------------------------------------
701 751 # Data movement
702 752 #--------------------------------------------------------------------------
703 753
704 754 @defaultblock
705 755 def push(self, ns, targets=None, block=None):
706 756 """Push the contents of `ns` into the namespace on `target`"""
707 757 if not isinstance(ns, dict):
708 758 raise TypeError("Must be a dict, not %s"%type(ns))
709 759 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
710 760 return result
711 761
712 762 @defaultblock
713 763 def pull(self, keys, targets=None, block=True):
714 764 """Pull objects from `target`'s namespace by `keys`"""
715 765 if isinstance(keys, str):
716 766 pass
717 767 elif isistance(keys, (list,tuple,set)):
718 768 for key in keys:
719 769 if not isinstance(key, str):
720 770 raise TypeError
721 771 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
722 772 return result
723 773
724 774 #--------------------------------------------------------------------------
725 775 # Query methods
726 776 #--------------------------------------------------------------------------
727 777
728 778 @spinfirst
729 779 def get_results(self, msg_ids, status_only=False):
730 780 """Returns the result of the execute or task request with `msg_ids`.
731 781
732 782 Parameters
733 783 ----------
734 784 msg_ids : list of ints or msg_ids
735 785 if int:
736 786 Passed as index to self.history for convenience.
737 787 status_only : bool (default: False)
738 788 if False:
739 789 return the actual results
740 790 """
741 791 if not isinstance(msg_ids, (list,tuple)):
742 792 msg_ids = [msg_ids]
743 793 theids = []
744 794 for msg_id in msg_ids:
745 795 if isinstance(msg_id, int):
746 796 msg_id = self.history[msg_id]
747 797 if not isinstance(msg_id, str):
748 798 raise TypeError("msg_ids must be str, not %r"%msg_id)
749 799 theids.append(msg_id)
750 800
751 801 completed = []
752 802 local_results = {}
753 803 for msg_id in list(theids):
754 804 if msg_id in self.results:
755 805 completed.append(msg_id)
756 806 local_results[msg_id] = self.results[msg_id]
757 807 theids.remove(msg_id)
758 808
759 809 if theids: # some not locally cached
760 810 content = dict(msg_ids=theids, status_only=status_only)
761 811 msg = self.session.send(self._query_socket, "result_request", content=content)
762 812 zmq.select([self._query_socket], [], [])
763 813 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
764 814 if self.debug:
765 815 pprint(msg)
766 816 content = msg['content']
767 817 if content['status'] != 'ok':
768 818 raise ss.unwrap_exception(content)
769 819 else:
770 820 content = dict(completed=[],pending=[])
771 821 if not status_only:
772 822 # load cached results into result:
773 823 content['completed'].extend(completed)
774 824 content.update(local_results)
775 825 # update cache with results:
776 826 for msg_id in msg_ids:
777 827 if msg_id in content['completed']:
778 828 self.results[msg_id] = content[msg_id]
779 829 return content
780 830
781 831 @spinfirst
782 832 def queue_status(self, targets=None, verbose=False):
783 833 """Fetch the status of engine queues.
784 834
785 835 Parameters
786 836 ----------
787 837 targets : int/str/list of ints/strs
788 838 the engines on which to execute
789 839 default : all
790 840 verbose : bool
791 841 whether to return lengths only, or lists of ids for each element
792 842 """
793 843 targets = self._build_targets(targets)[1]
794 844 content = dict(targets=targets, verbose=verbose)
795 845 self.session.send(self._query_socket, "queue_request", content=content)
796 846 idents,msg = self.session.recv(self._query_socket, 0)
797 847 if self.debug:
798 848 pprint(msg)
799 849 content = msg['content']
800 850 status = content.pop('status')
801 851 if status != 'ok':
802 852 raise ss.unwrap_exception(content)
803 853 return content
804 854
805 855 @spinfirst
806 856 def purge_results(self, msg_ids=[], targets=[]):
807 857 """Tell the controller to forget results.
808 858
809 859 Individual results can be purged by msg_id, or the entire
810 860 history of specific targets can
811 861
812 862 Parameters
813 863 ----------
814 864 targets : int/str/list of ints/strs
815 865 the targets
816 866 default : None
817 867 """
818 868 if not targets and not msg_ids:
819 869 raise ValueError
820 870 if targets:
821 871 targets = self._build_targets(targets)[1]
822 872 content = dict(targets=targets, msg_ids=msg_ids)
823 873 self.session.send(self._query_socket, "purge_request", content=content)
824 874 idents, msg = self.session.recv(self._query_socket, 0)
825 875 if self.debug:
826 876 pprint(msg)
827 877 content = msg['content']
828 878 if content['status'] != 'ok':
829 879 raise ss.unwrap_exception(content)
830 880
831 881 class AsynClient(Client):
832 882 """An Asynchronous client, using the Tornado Event Loop.
833 883 !!!unfinished!!!"""
834 884 io_loop = None
835 885 _queue_stream = None
836 886 _notifier_stream = None
837 887 _task_stream = None
838 888 _control_stream = None
839 889
840 890 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
841 891 Client.__init__(self, addr, context, username, debug)
842 892 if io_loop is None:
843 893 io_loop = ioloop.IOLoop.instance()
844 894 self.io_loop = io_loop
845 895
846 896 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
847 897 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
848 898 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
849 899 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
850 900
851 901 def spin(self):
852 902 for stream in (self.queue_stream, self.notifier_stream,
853 903 self.task_stream, self.control_stream):
854 904 stream.flush()
855 905
@@ -1,123 +1,300
1 """Basic ssh tunneling utilities."""
1 2
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2008-2010 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-----------------------------------------------------------------------------
2 9
3 #-----------------------------------------
10
11
12 #-----------------------------------------------------------------------------
4 13 # Imports
5 #-----------------------------------------
14 #-----------------------------------------------------------------------------
6 15
7 16 from __future__ import print_function
8 17
9 import os,sys
18 import os,sys, atexit
10 19 from multiprocessing import Process
11 20 from getpass import getpass, getuser
12 21
13 22 try:
14 23 import paramiko
15 24 except ImportError:
16 25 paramiko = None
17 26 else:
18 27 from forward import forward_tunnel
19 28
29 try:
20 30 from IPython.external import pexpect
31 except ImportError:
32 pexpect = None
33
34 from IPython.zmq.parallel.entry_point import select_random_ports
35
36 #-----------------------------------------------------------------------------
37 # Code
38 #-----------------------------------------------------------------------------
21 39
40 #-----------------------------------------------------------------------------
41 # Check for passwordless login
42 #-----------------------------------------------------------------------------
43
44 def try_passwordless_ssh(server, keyfile, paramiko=None):
45 """Attempt to make an ssh connection without a password.
46 This is mainly used for requiring password input only once
47 when many tunnels may be connected to the same server.
48
49 If paramiko is None, the default for the platform is chosen.
50 """
51 if paramiko is None:
52 paramiko = sys.platform == 'win32'
53 if not paramiko:
54 f = _try_passwordless_openssh
55 else:
56 f = _try_passwordless_paramiko
57 return f(server, keyfile)
58
59 def _try_passwordless_openssh(server, keyfile):
60 """Try passwordless login with shell ssh command."""
61 if pexpect is None:
62 raise ImportError("pexpect unavailable, use paramiko")
63 cmd = 'ssh -f '+ server
64 if keyfile:
65 cmd += ' -i ' + keyfile
66 cmd += ' exit'
67 p = pexpect.spawn(cmd)
68 while True:
69 try:
70 p.expect('[Ppassword]:', timeout=.1)
71 except pexpect.TIMEOUT:
72 continue
73 except pexpect.EOF:
74 return True
75 else:
76 return False
22 77
23 def launch_ssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, timeout=15):
78 def _try_passwordless_paramiko(server, keyfile):
79 """Try passwordless login with paramiko."""
80 if paramiko is None:
81 raise ImportError("paramiko unavailable, use openssh")
82 username, server, port = _split_server(server)
83 client = paramiko.SSHClient()
84 client.load_system_host_keys()
85 client.set_missing_host_key_policy(paramiko.WarningPolicy())
86 try:
87 client.connect(server, port, username=username, key_filename=keyfile,
88 look_for_keys=True)
89 except paramiko.AuthenticationException:
90 return False
91 else:
92 client.close()
93 return True
94
95
96 def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramiko=None):
97 """Connect a socket to an address via an ssh tunnel.
98
99 This is a wrapper for socket.connect(addr), when addr is not accessible
100 from the local machine. It simply creates an ssh tunnel using the remaining args,
101 and calls socket.connect('tcp://localhost:lport') where lport is the randomly
102 selected local port of the tunnel.
103
104 """
105 lport = select_random_ports(1)[0]
106 transport, addr = addr.split('://')
107 ip,rport = addr.split(':')
108 rport = int(rport)
109 if paramiko is None:
110 paramiko = sys.platform == 'win32'
111 if paramiko:
112 tunnelf = paramiko_tunnel
113 else:
114 tunnelf = openssh_tunnel
115 tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password)
116 socket.connect('tcp://127.0.0.1:%i'%lport)
117 return tunnel
118
119 def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15):
24 120 """Create an ssh tunnel using command-line ssh that connects port lport
25 121 on this machine to localhost:rport on server. The tunnel
26 122 will automatically close when not in use, remaining open
27 123 for a minimum of timeout seconds for an initial connection.
124
125 This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
126 as seen from `server`.
127
128 keyfile and password may be specified, but ssh config is checked for defaults.
129
130 Parameters
131 ----------
132
133 lport : int
134 local port for connecting to the tunnel from this machine.
135 rport : int
136 port on the remote machine to connect to.
137 server : str
138 The ssh server to connect to. The full ssh server string will be parsed.
139 user@server:port
140 remoteip : str [Default: 127.0.0.1]
141 The remote ip, specifying the destination of the tunnel.
142 Default is localhost, which means that the tunnel would redirect
143 localhost:lport on this machine to localhost:rport on the *server*.
144
145 keyfile : str; path to public key file
146 This specifies a key to be used in ssh login, default None.
147 Regular default ssh keys will be used without specifying this argument.
148 password : str;
149 Your ssh password to the ssh server. Note that if this is left None,
150 you will be prompted for it if passwordless key based login is unavailable.
151
28 152 """
153 if pexpect is None:
154 raise ImportError("pexpect unavailable, use paramiko_tunnel")
29 155 ssh="ssh "
30 156 if keyfile:
31 157 ssh += "-i " + keyfile
32 cmd = ssh + " -f -L %i:127.0.0.1:%i %s sleep %i"%(lport, rport, server, timeout)
158 cmd = ssh + " -f -L 127.0.0.1:%i:127.0.0.1:%i %s sleep %i"%(lport, rport, server, timeout)
33 159 tunnel = pexpect.spawn(cmd)
34 160 failed = False
35 161 while True:
36 162 try:
37 163 tunnel.expect('[Pp]assword:', timeout=.1)
38 164 except pexpect.TIMEOUT:
39 165 continue
40 166 except pexpect.EOF:
41 167 if tunnel.exitstatus:
42 168 print (tunnel.exitstatus)
43 169 print (tunnel.before)
44 170 print (tunnel.after)
45 171 raise RuntimeError("tunnel '%s' failed to start"%(cmd))
46 172 else:
47 173 return tunnel.pid
48 174 else:
49 175 if failed:
50 176 print("Password rejected, try again")
51 tunnel.sendline(getpass())
177 password=None
178 if password is None:
179 password = getpass("%s's password: "%(server))
180 tunnel.sendline(password)
52 181 failed = True
53 182
54 183 def _split_server(server):
55 184 if '@' in server:
56 185 username,server = server.split('@', 1)
57 186 else:
58 187 username = getuser()
59 188 if ':' in server:
60 189 server, port = server.split(':')
61 190 port = int(port)
62 191 else:
63 192 port = 22
64 193 return username, server, port
65 194
66 def launch_paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None):
67 """launch a tunner with paramiko in a subprocess"""
195 def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=15):
196 """launch a tunner with paramiko in a subprocess. This should only be used
197 when shell ssh is unavailable (e.g. Windows).
198
199 This creates a tunnel redirecting `localhost:lport` to `remoteip:rport`,
200 as seen from `server`.
201
202 keyfile and password may be specified, but ssh config is checked for defaults.
203
204 Parameters
205 ----------
206
207 lport : int
208 local port for connecting to the tunnel from this machine.
209 rport : int
210 port on the remote machine to connect to.
211 server : str
212 The ssh server to connect to. The full ssh server string will be parsed.
213 user@server:port
214 remoteip : str [Default: 127.0.0.1]
215 The remote ip, specifying the destination of the tunnel.
216 Default is localhost, which means that the tunnel would redirect
217 localhost:lport on this machine to localhost:rport on the *server*.
218
219 keyfile : str; path to public key file
220 This specifies a key to be used in ssh login, default None.
221 Regular default ssh keys will be used without specifying this argument.
222 password : str;
223 Your ssh password to the ssh server. Note that if this is left None,
224 you will be prompted for it if passwordless key based login is unavailable.
225
226 """
68 227 if paramiko is None:
69 228 raise ImportError("Paramiko not available")
70 server = _split_server(server)
71 if keyfile is None:
72 passwd = getpass("%s@%s's password: "%(server[0], server[1]))
73 else:
74 passwd = None
229
230 if password is None:
231 if not _check_passwordless_paramiko(server, keyfile):
232 password = getpass("%s's password: "%(server))
233
75 234 p = Process(target=_paramiko_tunnel,
76 235 args=(lport, rport, server, remoteip),
77 kwargs=dict(keyfile=keyfile, password=passwd))
236 kwargs=dict(keyfile=keyfile, password=password))
78 237 p.daemon=False
79 238 p.start()
239 atexit.register(_shutdown_process, p)
80 240 return p
81 241
242 def _shutdown_process(p):
243 if p.isalive():
244 p.terminate()
82 245
83 246 def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None):
84 247 """function for actually starting a paramiko tunnel, to be passed
85 248 to multiprocessing.Process(target=this).
86 249 """
87 username, server, port = server
250 username, server, port = _split_server(server)
88 251 client = paramiko.SSHClient()
89 252 client.load_system_host_keys()
90 253 client.set_missing_host_key_policy(paramiko.WarningPolicy())
91 254
92 255 try:
93 256 client.connect(server, port, username=username, key_filename=keyfile,
94 257 look_for_keys=True, password=password)
258 # except paramiko.AuthenticationException:
259 # if password is None:
260 # password = getpass("%s@%s's password: "%(username, server))
261 # client.connect(server, port, username=username, password=password)
262 # else:
263 # raise
95 264 except Exception as e:
96 265 print ('*** Failed to connect to %s:%d: %r' % (server, port, e))
97 266 sys.exit(1)
98 267
99 print ('Now forwarding port %d to %s:%d ...' % (lport, server, rport))
268 # print ('Now forwarding port %d to %s:%d ...' % (lport, server, rport))
100 269
101 270 try:
102 271 forward_tunnel(lport, remoteip, rport, client.get_transport())
103 272 except KeyboardInterrupt:
104 print ('C-c: Port forwarding stopped.')
273 print ('SIGINT: Port forwarding stopped cleanly')
105 274 sys.exit(0)
275 except Exception as e:
276 print ("Port forwarding stopped uncleanly: %s"%e)
277 sys.exit(255)
278
279 if sys.platform == 'win32':
280 ssh_tunnel = paramiko_tunnel
281 else:
282 ssh_tunnel = openssh_tunnel
106 283
107 284
108 __all__ = ['launch_ssh_tunnel', 'launch_paramiko_tunnel']
285 __all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh']
109 286
110 287
111 288
112 289
113 290
114 291
115 292
116 293
117 294
118 295
119 296
120 297
121 298
122 299
123 300
General Comments 0
You need to be logged in to leave comments. Login now