##// END OF EJS Templates
Merge pull request #3828 from minrk/signature_scheme...
Min RK -
r11850:c9c96208 merge
parent child Browse files
Show More
@@ -1,557 +1,557 b''
1 1 """Utilities for connecting to kernels
2 2
3 3 Authors:
4 4
5 5 * Min Ragan-Kelley
6 6
7 7 """
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2013 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19
20 20 from __future__ import absolute_import
21 21
22 22 import glob
23 23 import json
24 24 import os
25 25 import socket
26 26 import sys
27 27 from getpass import getpass
28 28 from subprocess import Popen, PIPE
29 29 import tempfile
30 30
31 31 import zmq
32 32
33 33 # external imports
34 34 from IPython.external.ssh import tunnel
35 35
36 36 # IPython imports
37 37 # from IPython.config import Configurable
38 38 from IPython.core.profiledir import ProfileDir
39 39 from IPython.utils.localinterfaces import LOCALHOST
40 40 from IPython.utils.path import filefind, get_ipython_dir
41 41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str
42 42 from IPython.utils.traitlets import (
43 43 Bool, Integer, Unicode, CaselessStrEnum,
44 44 HasTraits,
45 45 )
46 46
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Working with Connection Files
50 50 #-----------------------------------------------------------------------------
51 51
52 52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
53 53 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
54 54 signature_scheme='hmac-sha256',
55 55 ):
56 56 """Generates a JSON config file, including the selection of random ports.
57 57
58 58 Parameters
59 59 ----------
60 60
61 61 fname : unicode
62 62 The path to the file to write
63 63
64 64 shell_port : int, optional
65 65 The port to use for ROUTER (shell) channel.
66 66
67 67 iopub_port : int, optional
68 68 The port to use for the SUB channel.
69 69
70 70 stdin_port : int, optional
71 71 The port to use for the ROUTER (raw input) channel.
72 72
73 73 control_port : int, optional
74 74 The port to use for the ROUTER (control) channel.
75 75
76 76 hb_port : int, optional
77 77 The port to use for the heartbeat REP channel.
78 78
79 79 ip : str, optional
80 80 The ip address the kernel will bind to.
81 81
82 82 key : str, optional
83 83 The Session key used for message authentication.
84 84
85 85 signature_scheme : str, optional
86 86 The scheme used for message authentication.
87 87 This has the form 'digest-hash', where 'digest'
88 88 is the scheme used for digests, and 'hash' is the name of the hash function
89 89 used by the digest scheme.
90 90 Currently, 'hmac' is the only supported digest scheme,
91 91 and 'sha256' is the default hash function.
92 92
93 93 """
94 94 # default to temporary connector file
95 95 if not fname:
96 96 fname = tempfile.mktemp('.json')
97 97
98 98 # Find open ports as necessary.
99 99
100 100 ports = []
101 101 ports_needed = int(shell_port <= 0) + \
102 102 int(iopub_port <= 0) + \
103 103 int(stdin_port <= 0) + \
104 104 int(control_port <= 0) + \
105 105 int(hb_port <= 0)
106 106 if transport == 'tcp':
107 107 for i in range(ports_needed):
108 108 sock = socket.socket()
109 109 sock.bind(('', 0))
110 110 ports.append(sock)
111 111 for i, sock in enumerate(ports):
112 112 port = sock.getsockname()[1]
113 113 sock.close()
114 114 ports[i] = port
115 115 else:
116 116 N = 1
117 117 for i in range(ports_needed):
118 118 while os.path.exists("%s-%s" % (ip, str(N))):
119 119 N += 1
120 120 ports.append(N)
121 121 N += 1
122 122 if shell_port <= 0:
123 123 shell_port = ports.pop(0)
124 124 if iopub_port <= 0:
125 125 iopub_port = ports.pop(0)
126 126 if stdin_port <= 0:
127 127 stdin_port = ports.pop(0)
128 128 if control_port <= 0:
129 129 control_port = ports.pop(0)
130 130 if hb_port <= 0:
131 131 hb_port = ports.pop(0)
132 132
133 133 cfg = dict( shell_port=shell_port,
134 134 iopub_port=iopub_port,
135 135 stdin_port=stdin_port,
136 136 control_port=control_port,
137 137 hb_port=hb_port,
138 138 )
139 139 cfg['ip'] = ip
140 140 cfg['key'] = bytes_to_str(key)
141 141 cfg['transport'] = transport
142 142 cfg['signature_scheme'] = signature_scheme
143 143
144 144 with open(fname, 'w') as f:
145 145 f.write(json.dumps(cfg, indent=2))
146 146
147 147 return fname, cfg
148 148
149 149
150 150 def get_connection_file(app=None):
151 151 """Return the path to the connection file of an app
152 152
153 153 Parameters
154 154 ----------
155 155 app : IPKernelApp instance [optional]
156 156 If unspecified, the currently running app will be used
157 157 """
158 158 if app is None:
159 159 from IPython.kernel.zmq.kernelapp import IPKernelApp
160 160 if not IPKernelApp.initialized():
161 161 raise RuntimeError("app not specified, and not in a running Kernel")
162 162
163 163 app = IPKernelApp.instance()
164 164 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
165 165
166 166
167 167 def find_connection_file(filename, profile=None):
168 168 """find a connection file, and return its absolute path.
169 169
170 170 The current working directory and the profile's security
171 171 directory will be searched for the file if it is not given by
172 172 absolute path.
173 173
174 174 If profile is unspecified, then the current running application's
175 175 profile will be used, or 'default', if not run from IPython.
176 176
177 177 If the argument does not match an existing file, it will be interpreted as a
178 178 fileglob, and the matching file in the profile's security dir with
179 179 the latest access time will be used.
180 180
181 181 Parameters
182 182 ----------
183 183 filename : str
184 184 The connection file or fileglob to search for.
185 185 profile : str [optional]
186 186 The name of the profile to use when searching for the connection file,
187 187 if different from the current IPython session or 'default'.
188 188
189 189 Returns
190 190 -------
191 191 str : The absolute path of the connection file.
192 192 """
193 193 from IPython.core.application import BaseIPythonApplication as IPApp
194 194 try:
195 195 # quick check for absolute path, before going through logic
196 196 return filefind(filename)
197 197 except IOError:
198 198 pass
199 199
200 200 if profile is None:
201 201 # profile unspecified, check if running from an IPython app
202 202 if IPApp.initialized():
203 203 app = IPApp.instance()
204 204 profile_dir = app.profile_dir
205 205 else:
206 206 # not running in IPython, use default profile
207 207 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
208 208 else:
209 209 # find profiledir by profile name:
210 210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
211 211 security_dir = profile_dir.security_dir
212 212
213 213 try:
214 214 # first, try explicit name
215 215 return filefind(filename, ['.', security_dir])
216 216 except IOError:
217 217 pass
218 218
219 219 # not found by full name
220 220
221 221 if '*' in filename:
222 222 # given as a glob already
223 223 pat = filename
224 224 else:
225 225 # accept any substring match
226 226 pat = '*%s*' % filename
227 227 matches = glob.glob( os.path.join(security_dir, pat) )
228 228 if not matches:
229 229 raise IOError("Could not find %r in %r" % (filename, security_dir))
230 230 elif len(matches) == 1:
231 231 return matches[0]
232 232 else:
233 233 # get most recent match, by access time:
234 234 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
235 235
236 236
237 237 def get_connection_info(connection_file=None, unpack=False, profile=None):
238 238 """Return the connection information for the current Kernel.
239 239
240 240 Parameters
241 241 ----------
242 242 connection_file : str [optional]
243 243 The connection file to be used. Can be given by absolute path, or
244 244 IPython will search in the security directory of a given profile.
245 245 If run from IPython,
246 246
247 247 If unspecified, the connection file for the currently running
248 248 IPython Kernel will be used, which is only allowed from inside a kernel.
249 249 unpack : bool [default: False]
250 250 if True, return the unpacked dict, otherwise just the string contents
251 251 of the file.
252 252 profile : str [optional]
253 253 The name of the profile to use when searching for the connection file,
254 254 if different from the current IPython session or 'default'.
255 255
256 256
257 257 Returns
258 258 -------
259 259 The connection dictionary of the current kernel, as string or dict,
260 260 depending on `unpack`.
261 261 """
262 262 if connection_file is None:
263 263 # get connection file from current kernel
264 264 cf = get_connection_file()
265 265 else:
266 266 # connection file specified, allow shortnames:
267 267 cf = find_connection_file(connection_file, profile=profile)
268 268
269 269 with open(cf) as f:
270 270 info = f.read()
271 271
272 272 if unpack:
273 273 info = json.loads(info)
274 274 # ensure key is bytes:
275 275 info['key'] = str_to_bytes(info.get('key', ''))
276 276 return info
277 277
278 278
279 279 def connect_qtconsole(connection_file=None, argv=None, profile=None):
280 280 """Connect a qtconsole to the current kernel.
281 281
282 282 This is useful for connecting a second qtconsole to a kernel, or to a
283 283 local notebook.
284 284
285 285 Parameters
286 286 ----------
287 287 connection_file : str [optional]
288 288 The connection file to be used. Can be given by absolute path, or
289 289 IPython will search in the security directory of a given profile.
290 290 If run from IPython,
291 291
292 292 If unspecified, the connection file for the currently running
293 293 IPython Kernel will be used, which is only allowed from inside a kernel.
294 294 argv : list [optional]
295 295 Any extra args to be passed to the console.
296 296 profile : str [optional]
297 297 The name of the profile to use when searching for the connection file,
298 298 if different from the current IPython session or 'default'.
299 299
300 300
301 301 Returns
302 302 -------
303 303 subprocess.Popen instance running the qtconsole frontend
304 304 """
305 305 argv = [] if argv is None else argv
306 306
307 307 if connection_file is None:
308 308 # get connection file from current kernel
309 309 cf = get_connection_file()
310 310 else:
311 311 cf = find_connection_file(connection_file, profile=profile)
312 312
313 313 cmd = ';'.join([
314 314 "from IPython.qt.console import qtconsoleapp",
315 315 "qtconsoleapp.main()"
316 316 ])
317 317
318 318 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
319 319 stdout=PIPE, stderr=PIPE, close_fds=True,
320 320 )
321 321
322 322
323 323 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
324 324 """tunnel connections to a kernel via ssh
325 325
326 326 This will open four SSH tunnels from localhost on this machine to the
327 327 ports associated with the kernel. They can be either direct
328 328 localhost-localhost tunnels, or if an intermediate server is necessary,
329 329 the kernel must be listening on a public IP.
330 330
331 331 Parameters
332 332 ----------
333 333 connection_info : dict or str (path)
334 334 Either a connection dict, or the path to a JSON connection file
335 335 sshserver : str
336 336 The ssh sever to use to tunnel to the kernel. Can be a full
337 337 `user@server:port` string. ssh config aliases are respected.
338 338 sshkey : str [optional]
339 339 Path to file containing ssh key to use for authentication.
340 340 Only necessary if your ssh config does not already associate
341 341 a keyfile with the host.
342 342
343 343 Returns
344 344 -------
345 345
346 346 (shell, iopub, stdin, hb) : ints
347 347 The four ports on localhost that have been forwarded to the kernel.
348 348 """
349 349 if isinstance(connection_info, basestring):
350 350 # it's a path, unpack it
351 351 with open(connection_info) as f:
352 352 connection_info = json.loads(f.read())
353 353
354 354 cf = connection_info
355 355
356 356 lports = tunnel.select_random_ports(4)
357 357 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
358 358
359 359 remote_ip = cf['ip']
360 360
361 361 if tunnel.try_passwordless_ssh(sshserver, sshkey):
362 362 password=False
363 363 else:
364 364 password = getpass("SSH Password for %s: "%sshserver)
365 365
366 366 for lp,rp in zip(lports, rports):
367 367 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
368 368
369 369 return tuple(lports)
370 370
371 371
372 372 #-----------------------------------------------------------------------------
373 373 # Mixin for classes that work with connection files
374 374 #-----------------------------------------------------------------------------
375 375
376 376 channel_socket_types = {
377 377 'hb' : zmq.REQ,
378 378 'shell' : zmq.DEALER,
379 379 'iopub' : zmq.SUB,
380 380 'stdin' : zmq.DEALER,
381 381 'control': zmq.DEALER,
382 382 }
383 383
384 384 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
385 385
386 386 class ConnectionFileMixin(HasTraits):
387 387 """Mixin for configurable classes that work with connection files"""
388 388
389 389 # The addresses for the communication channels
390 390 connection_file = Unicode('')
391 391 _connection_file_written = Bool(False)
392 392
393 393 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
394 signature_scheme = Unicode('')
395 394
396 395 ip = Unicode(LOCALHOST, config=True,
397 396 help="""Set the kernel\'s IP address [default localhost].
398 397 If the IP address is something other than localhost, then
399 398 Consoles on other machines will be able to connect
400 399 to the Kernel, so be careful!"""
401 400 )
402 401
403 402 def _ip_default(self):
404 403 if self.transport == 'ipc':
405 404 if self.connection_file:
406 405 return os.path.splitext(self.connection_file)[0] + '-ipc'
407 406 else:
408 407 return 'kernel-ipc'
409 408 else:
410 409 return LOCALHOST
411 410
412 411 def _ip_changed(self, name, old, new):
413 412 if new == '*':
414 413 self.ip = '0.0.0.0'
415 414
416 415 # protected traits
417 416
418 417 shell_port = Integer(0)
419 418 iopub_port = Integer(0)
420 419 stdin_port = Integer(0)
421 420 control_port = Integer(0)
422 421 hb_port = Integer(0)
423 422
424 423 @property
425 424 def ports(self):
426 425 return [ getattr(self, name) for name in port_names ]
427 426
428 427 #--------------------------------------------------------------------------
429 428 # Connection and ipc file management
430 429 #--------------------------------------------------------------------------
431 430
432 431 def get_connection_info(self):
433 432 """return the connection info as a dict"""
434 433 return dict(
435 434 transport=self.transport,
436 435 ip=self.ip,
437 436 shell_port=self.shell_port,
438 437 iopub_port=self.iopub_port,
439 438 stdin_port=self.stdin_port,
440 439 hb_port=self.hb_port,
441 440 control_port=self.control_port,
442 signature_scheme=self.signature_scheme,
441 signature_scheme=self.session.signature_scheme,
442 key=self.session.key,
443 443 )
444 444
445 445 def cleanup_connection_file(self):
446 446 """Cleanup connection file *if we wrote it*
447 447
448 448 Will not raise if the connection file was already removed somehow.
449 449 """
450 450 if self._connection_file_written:
451 451 # cleanup connection files on full shutdown of kernel we started
452 452 self._connection_file_written = False
453 453 try:
454 454 os.remove(self.connection_file)
455 455 except (IOError, OSError, AttributeError):
456 456 pass
457 457
458 458 def cleanup_ipc_files(self):
459 459 """Cleanup ipc files if we wrote them."""
460 460 if self.transport != 'ipc':
461 461 return
462 462 for port in self.ports:
463 463 ipcfile = "%s-%i" % (self.ip, port)
464 464 try:
465 465 os.remove(ipcfile)
466 466 except (IOError, OSError):
467 467 pass
468 468
469 469 def write_connection_file(self):
470 470 """Write connection info to JSON dict in self.connection_file."""
471 471 if self._connection_file_written:
472 472 return
473 473
474 474 self.connection_file, cfg = write_connection_file(self.connection_file,
475 475 transport=self.transport, ip=self.ip, key=self.session.key,
476 476 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
477 477 shell_port=self.shell_port, hb_port=self.hb_port,
478 478 control_port=self.control_port,
479 signature_scheme=self.signature_scheme,
479 signature_scheme=self.session.signature_scheme,
480 480 )
481 481 # write_connection_file also sets default ports:
482 482 for name in port_names:
483 483 setattr(self, name, cfg[name])
484 484
485 485 self._connection_file_written = True
486 486
487 487 def load_connection_file(self):
488 488 """Load connection info from JSON dict in self.connection_file."""
489 489 with open(self.connection_file) as f:
490 490 cfg = json.loads(f.read())
491 491
492 492 self.transport = cfg.get('transport', 'tcp')
493 493 self.ip = cfg['ip']
494 494 for name in port_names:
495 495 setattr(self, name, cfg[name])
496 496 if 'key' in cfg:
497 497 self.session.key = str_to_bytes(cfg['key'])
498 498 if cfg.get('signature_scheme'):
499 499 self.session.signature_scheme = cfg['signature_scheme']
500 500
501 501 #--------------------------------------------------------------------------
502 502 # Creating connected sockets
503 503 #--------------------------------------------------------------------------
504 504
505 505 def _make_url(self, channel):
506 506 """Make a ZeroMQ URL for a given channel."""
507 507 transport = self.transport
508 508 ip = self.ip
509 509 port = getattr(self, '%s_port' % channel)
510 510
511 511 if transport == 'tcp':
512 512 return "tcp://%s:%i" % (ip, port)
513 513 else:
514 514 return "%s://%s-%s" % (transport, ip, port)
515 515
516 516 def _create_connected_socket(self, channel, identity=None):
517 517 """Create a zmq Socket and connect it to the kernel."""
518 518 url = self._make_url(channel)
519 519 socket_type = channel_socket_types[channel]
520 520 self.log.info("Connecting to: %s" % url)
521 521 sock = self.context.socket(socket_type)
522 522 if identity:
523 523 sock.identity = identity
524 524 sock.connect(url)
525 525 return sock
526 526
527 527 def connect_iopub(self, identity=None):
528 528 """return zmq Socket connected to the IOPub channel"""
529 529 sock = self._create_connected_socket('iopub', identity=identity)
530 530 sock.setsockopt(zmq.SUBSCRIBE, b'')
531 531 return sock
532 532
533 533 def connect_shell(self, identity=None):
534 534 """return zmq Socket connected to the Shell channel"""
535 535 return self._create_connected_socket('shell', identity=identity)
536 536
537 537 def connect_stdin(self, identity=None):
538 538 """return zmq Socket connected to the StdIn channel"""
539 539 return self._create_connected_socket('stdin', identity=identity)
540 540
541 541 def connect_hb(self, identity=None):
542 542 """return zmq Socket connected to the Heartbeat channel"""
543 543 return self._create_connected_socket('hb', identity=identity)
544 544
545 545 def connect_control(self, identity=None):
546 546 """return zmq Socket connected to the Heartbeat channel"""
547 547 return self._create_connected_socket('control', identity=identity)
548 548
549 549
550 550 __all__ = [
551 551 'write_connection_file',
552 552 'get_connection_file',
553 553 'find_connection_file',
554 554 'get_connection_info',
555 555 'connect_qtconsole',
556 556 'tunnel_to_kernel',
557 557 ]
@@ -1,50 +1,61 b''
1 1 """Tests for the notebook kernel and session manager"""
2 2
3 3 from subprocess import PIPE
4 4 import time
5 5 from unittest import TestCase
6 6
7 7 from IPython.testing import decorators as dec
8 8
9 9 from IPython.config.loader import Config
10 10 from IPython.kernel import KernelManager
11 11
12 12 class TestKernelManager(TestCase):
13 13
14 14 def _get_tcp_km(self):
15 15 c = Config()
16 16 km = KernelManager(config=c)
17 17 return km
18 18
19 19 def _get_ipc_km(self):
20 20 c = Config()
21 21 c.KernelManager.transport = 'ipc'
22 22 c.KernelManager.ip = 'test'
23 23 km = KernelManager(config=c)
24 24 return km
25 25
26 26 def _run_lifecycle(self, km):
27 27 km.start_kernel(stdout=PIPE, stderr=PIPE)
28 28 self.assertTrue(km.is_alive())
29 29 km.restart_kernel()
30 30 self.assertTrue(km.is_alive())
31 31 # We need a delay here to give the restarting kernel a chance to
32 32 # restart. Otherwise, the interrupt will kill it, causing the test
33 33 # suite to hang. The reason it *hangs* is that the shutdown
34 34 # message for the restart sometimes hasn't been sent to the kernel.
35 35 # Because linger is oo on the shell channel, the context can't
36 36 # close until the message is sent to the kernel, which is not dead.
37 37 time.sleep(1.0)
38 38 km.interrupt_kernel()
39 39 self.assertTrue(isinstance(km, KernelManager))
40 40 km.shutdown_kernel()
41 41
42 42 def test_tcp_lifecycle(self):
43 43 km = self._get_tcp_km()
44 44 self._run_lifecycle(km)
45 45
46 46 @dec.skip_win32
47 47 def test_ipc_lifecycle(self):
48 48 km = self._get_ipc_km()
49 49 self._run_lifecycle(km)
50
51 def test_get_connect_info(self):
52 km = self._get_tcp_km()
53 cinfo = km.get_connection_info()
54 keys = sorted(cinfo.keys())
55 expected = sorted([
56 'ip', 'transport',
57 'hb_port', 'shell_port', 'stdin_port', 'iopub_port', 'control_port',
58 'key', 'signature_scheme',
59 ])
60 self.assertEqual(keys, expected)
50 61
General Comments 0
You need to be logged in to leave comments. Login now