remove IPython.external.ssh...
MinRK
@@ -1,578 +1,576 @@
1 1 """Utilities for connecting to kernels
2 2
3 3 The :class:`ConnectionFileMixin` class in this module encapsulates the logic
4 4 related to writing and reading connections files.
5 5 """
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import absolute_import
14 14
15 15 import glob
16 16 import json
17 17 import os
18 18 import socket
19 19 import sys
20 20 from getpass import getpass
21 21 from subprocess import Popen, PIPE
22 22 import tempfile
23 23
24 24 import zmq
25
26 # external imports
27 from IPython.external.ssh import tunnel
25 from zmq.ssh import tunnel
28 26
29 27 # IPython imports
30 28 from IPython.config import LoggingConfigurable
31 29 from IPython.core.profiledir import ProfileDir
32 30 from IPython.utils.localinterfaces import localhost
33 31 from IPython.utils.path import filefind, get_ipython_dir
34 32 from IPython.utils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
35 33 string_types)
36 34 from IPython.utils.traitlets import (
37 35 Bool, Integer, Unicode, CaselessStrEnum, Instance,
38 36 )
39 37
40 38
41 39 #-----------------------------------------------------------------------------
42 40 # Working with Connection Files
43 41 #-----------------------------------------------------------------------------
44 42
45 43 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
46 44 control_port=0, ip='', key=b'', transport='tcp',
47 45 signature_scheme='hmac-sha256',
48 46 ):
49 47 """Generates a JSON config file, including the selection of random ports.
50 48
51 49 Parameters
52 50 ----------
53 51
54 52 fname : unicode
55 53 The path to the file to write
56 54
57 55 shell_port : int, optional
58 56 The port to use for ROUTER (shell) channel.
59 57
60 58 iopub_port : int, optional
61 59 The port to use for the IOPub (PUB) channel.
62 60
63 61 stdin_port : int, optional
64 62 The port to use for the ROUTER (raw input) channel.
65 63
66 64 control_port : int, optional
67 65 The port to use for the ROUTER (control) channel.
68 66
69 67 hb_port : int, optional
70 68 The port to use for the heartbeat REP channel.
71 69
72 70 ip : str, optional
73 71 The ip address the kernel will bind to.
74 72
75 73 key : str, optional
76 74 The Session key used for message authentication.
77 75
78 76 signature_scheme : str, optional
79 77 The scheme used for message authentication.
80 78 This has the form 'digest-hash', where 'digest'
81 79 is the scheme used for digests, and 'hash' is the name of the hash function
82 80 used by the digest scheme.
83 81 Currently, 'hmac' is the only supported digest scheme,
84 82 and 'sha256' is the default hash function.
85 83
86 84 """
87 85 if not ip:
88 86 ip = localhost()
89 87 # default to a temporary connection file
90 88 if not fname:
91 89 fd, fname = tempfile.mkstemp('.json')
92 90 os.close(fd)
93 91
94 92 # Find open ports as necessary.
95 93
96 94 ports = []
97 95 ports_needed = int(shell_port <= 0) + \
98 96 int(iopub_port <= 0) + \
99 97 int(stdin_port <= 0) + \
100 98 int(control_port <= 0) + \
101 99 int(hb_port <= 0)
102 100 if transport == 'tcp':
103 101 for i in range(ports_needed):
104 102 sock = socket.socket()
105 103 # struct.pack('ii', (0,0)) is 8 null bytes
106 104 sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8)
107 105 sock.bind(('', 0))
108 106 ports.append(sock)
109 107 for i, sock in enumerate(ports):
110 108 port = sock.getsockname()[1]
111 109 sock.close()
112 110 ports[i] = port
113 111 else:
114 112 N = 1
115 113 for i in range(ports_needed):
116 114 while os.path.exists("%s-%s" % (ip, str(N))):
117 115 N += 1
118 116 ports.append(N)
119 117 N += 1
120 118 if shell_port <= 0:
121 119 shell_port = ports.pop(0)
122 120 if iopub_port <= 0:
123 121 iopub_port = ports.pop(0)
124 122 if stdin_port <= 0:
125 123 stdin_port = ports.pop(0)
126 124 if control_port <= 0:
127 125 control_port = ports.pop(0)
128 126 if hb_port <= 0:
129 127 hb_port = ports.pop(0)
130 128
131 129 cfg = dict( shell_port=shell_port,
132 130 iopub_port=iopub_port,
133 131 stdin_port=stdin_port,
134 132 control_port=control_port,
135 133 hb_port=hb_port,
136 134 )
137 135 cfg['ip'] = ip
138 136 cfg['key'] = bytes_to_str(key)
139 137 cfg['transport'] = transport
140 138 cfg['signature_scheme'] = signature_scheme
141 139
142 140 with open(fname, 'w') as f:
143 141 f.write(json.dumps(cfg, indent=2))
144 142
145 143 return fname, cfg
146 144
147 145
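# A minimal usage sketch of write_connection_file (key and path shown are
# placeholders; when fname is omitted a tempfile is created):
#
#     fname, cfg = write_connection_file(key=b'a0b1c2', transport='tcp')
#     print(fname)               # e.g. /tmp/tmpXXXXXX.json
#     print(cfg['shell_port'])   # a randomly selected open port
#     with open(fname) as f:
#         assert json.load(f)['transport'] == 'tcp'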
148 146 def get_connection_file(app=None):
149 147 """Return the path to the connection file of an app
150 148
151 149 Parameters
152 150 ----------
153 151 app : IPKernelApp instance [optional]
154 152 If unspecified, the currently running app will be used
155 153 """
156 154 if app is None:
157 155 from IPython.kernel.zmq.kernelapp import IPKernelApp
158 156 if not IPKernelApp.initialized():
159 157 raise RuntimeError("app not specified, and not in a running Kernel")
160 158
161 159 app = IPKernelApp.instance()
162 160 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
163 161
164 162
165 163 def find_connection_file(filename, profile=None):
166 164 """find a connection file, and return its absolute path.
167 165
168 166 The current working directory and the profile's security
169 167 directory will be searched for the file if it is not given by
170 168 absolute path.
171 169
172 170 If profile is unspecified, then the current running application's
173 171 profile will be used, or 'default', if not run from IPython.
174 172
175 173 If the argument does not match an existing file, it will be interpreted as a
176 174 fileglob, and the matching file in the profile's security dir with
177 175 the latest access time will be used.
178 176
179 177 Parameters
180 178 ----------
181 179 filename : str
182 180 The connection file or fileglob to search for.
183 181 profile : str [optional]
184 182 The name of the profile to use when searching for the connection file,
185 183 if different from the current IPython session or 'default'.
186 184
187 185 Returns
188 186 -------
189 187 str : The absolute path of the connection file.
190 188 """
191 189 from IPython.core.application import BaseIPythonApplication as IPApp
192 190 try:
193 191 # quick check for absolute path, before going through logic
194 192 return filefind(filename)
195 193 except IOError:
196 194 pass
197 195
198 196 if profile is None:
199 197 # profile unspecified, check if running from an IPython app
200 198 if IPApp.initialized():
201 199 app = IPApp.instance()
202 200 profile_dir = app.profile_dir
203 201 else:
204 202 # not running in IPython, use default profile
205 203 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
206 204 else:
207 205 # find profiledir by profile name:
208 206 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
209 207 security_dir = profile_dir.security_dir
210 208
211 209 try:
212 210 # first, try explicit name
213 211 return filefind(filename, ['.', security_dir])
214 212 except IOError:
215 213 pass
216 214
217 215 # not found by full name
218 216
219 217 if '*' in filename:
220 218 # given as a glob already
221 219 pat = filename
222 220 else:
223 221 # accept any substring match
224 222 pat = '*%s*' % filename
225 223 matches = glob.glob( os.path.join(security_dir, pat) )
226 224 if not matches:
227 225 raise IOError("Could not find %r in %r" % (filename, security_dir))
228 226 elif len(matches) == 1:
229 227 return matches[0]
230 228 else:
231 229 # get most recent match, by access time:
232 230 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
233 231
234 232
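# Usage sketch: how lookup falls back from exact name to substring glob
# (filenames below are placeholders):
#
#     cf = find_connection_file('/abs/path/kernel-1234.json')  # absolute path
#     cf = find_connection_file('kernel-1234.json')  # exact name in security dir
#     cf = find_connection_file('1234')              # becomes glob '*1234*'
#     cf = find_connection_file('kernel-*.json')     # used as a glob directly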
235 233 def get_connection_info(connection_file=None, unpack=False, profile=None):
236 234 """Return the connection information for the current Kernel.
237 235
238 236 Parameters
239 237 ----------
240 238 connection_file : str [optional]
241 239 The connection file to be used. Can be given by absolute path, or
242 240 IPython will search in the security directory of a given profile.
244 242
245 243 If unspecified, the connection file for the currently running
246 244 IPython Kernel will be used, which is only allowed from inside a kernel.
247 245 unpack : bool [default: False]
248 246 if True, return the unpacked dict, otherwise just the string contents
249 247 of the file.
250 248 profile : str [optional]
251 249 The name of the profile to use when searching for the connection file,
252 250 if different from the current IPython session or 'default'.
253 251
254 252
255 253 Returns
256 254 -------
257 255 The connection dictionary of the current kernel, as string or dict,
258 256 depending on `unpack`.
259 257 """
260 258 if connection_file is None:
261 259 # get connection file from current kernel
262 260 cf = get_connection_file()
263 261 else:
264 262 # connection file specified, allow shortnames:
265 263 cf = find_connection_file(connection_file, profile=profile)
266 264
267 265 with open(cf) as f:
268 266 info = f.read()
269 267
270 268 if unpack:
271 269 info = json.loads(info)
272 270 # ensure key is bytes:
273 271 info['key'] = str_to_bytes(info.get('key', ''))
274 272 return info
275 273
276 274
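# Usage sketch (calling with no file is only valid from inside a running
# kernel):
#
#     info = get_connection_info(unpack=True)
#     info['key']          # bytes, ready for Session authentication
#     info['shell_port']   # int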
277 275 def connect_qtconsole(connection_file=None, argv=None, profile=None):
278 276 """Connect a qtconsole to the current kernel.
279 277
280 278 This is useful for connecting a second qtconsole to a kernel, or to a
281 279 local notebook.
282 280
283 281 Parameters
284 282 ----------
285 283 connection_file : str [optional]
286 284 The connection file to be used. Can be given by absolute path, or
287 285 IPython will search in the security directory of a given profile.
289 287
290 288 If unspecified, the connection file for the currently running
291 289 IPython Kernel will be used, which is only allowed from inside a kernel.
292 290 argv : list [optional]
293 291 Any extra args to be passed to the console.
294 292 profile : str [optional]
295 293 The name of the profile to use when searching for the connection file,
296 294 if different from the current IPython session or 'default'.
297 295
298 296
299 297 Returns
300 298 -------
301 299 :class:`subprocess.Popen` instance running the qtconsole frontend
302 300 """
303 301 argv = [] if argv is None else argv
304 302
305 303 if connection_file is None:
306 304 # get connection file from current kernel
307 305 cf = get_connection_file()
308 306 else:
309 307 cf = find_connection_file(connection_file, profile=profile)
310 308
311 309 cmd = ';'.join([
312 310 "from IPython.qt.console import qtconsoleapp",
313 311 "qtconsoleapp.main()"
314 312 ])
315 313
316 314 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
317 315 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
318 316 )
319 317
320 318
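# Usage sketch (extra argv entries are passed straight through to the
# qtconsole; the flag shown is illustrative):
#
#     proc = connect_qtconsole(argv=['--no-confirm-exit'])
#     # proc is a subprocess.Popen; shut it down with proc.terminate()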
321 319 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
322 320 """tunnel connections to a kernel via ssh
323 321
324 322 This will open four SSH tunnels from localhost on this machine to the
325 323 ports associated with the kernel. They can be either direct
326 324 localhost-localhost tunnels, or if an intermediate server is necessary,
327 325 the kernel must be listening on a public IP.
328 326
329 327 Parameters
330 328 ----------
331 329 connection_info : dict or str (path)
332 330 Either a connection dict, or the path to a JSON connection file
333 331 sshserver : str
334 332 The ssh server to use to tunnel to the kernel. Can be a full
335 333 `user@server:port` string. ssh config aliases are respected.
336 334 sshkey : str [optional]
337 335 Path to file containing ssh key to use for authentication.
338 336 Only necessary if your ssh config does not already associate
339 337 a keyfile with the host.
340 338
341 339 Returns
342 340 -------
343 341
344 342 (shell, iopub, stdin, hb) : ints
345 343 The four ports on localhost that have been forwarded to the kernel.
346 344 """
347 345 if isinstance(connection_info, string_types):
348 346 # it's a path, unpack it
349 347 with open(connection_info) as f:
350 348 connection_info = json.loads(f.read())
351 349
352 350 cf = connection_info
353 351
354 352 lports = tunnel.select_random_ports(4)
355 353 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
356 354
357 355 remote_ip = cf['ip']
358 356
359 357 if tunnel.try_passwordless_ssh(sshserver, sshkey):
360 358 password=False
361 359 else:
362 360 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
363 361
364 362 for lp,rp in zip(lports, rports):
365 363 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
366 364
367 365 return tuple(lports)
368 366
369 367
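# Usage sketch (hostname and key path are placeholders):
#
#     ports = tunnel_to_kernel('kernel-1234.json', 'me@gateway.example.com',
#                              sshkey='~/.ssh/id_rsa')
#     shell, iopub, stdin, hb = ports  # localhost ports forwarded to the kernel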
370 368 #-----------------------------------------------------------------------------
371 369 # Mixin for classes that work with connection files
372 370 #-----------------------------------------------------------------------------
373 371
374 372 channel_socket_types = {
375 373 'hb' : zmq.REQ,
376 374 'shell' : zmq.DEALER,
377 375 'iopub' : zmq.SUB,
378 376 'stdin' : zmq.DEALER,
379 377 'control': zmq.DEALER,
380 378 }
381 379
382 380 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
383 381
384 382 class ConnectionFileMixin(LoggingConfigurable):
385 383 """Mixin for configurable classes that work with connection files"""
386 384
387 385 # The addresses for the communication channels
388 386 connection_file = Unicode('', config=True,
389 387 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
390 388
391 389 This file will contain the IP, ports, and authentication key needed to connect
392 390 clients to this kernel. By default, this file will be created in the security dir
393 391 of the current profile, but can be specified by absolute path.
394 392 """)
395 393 _connection_file_written = Bool(False)
396 394
397 395 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
398 396
399 397 ip = Unicode(config=True,
400 398 help="""Set the kernel\'s IP address [default localhost].
401 399 If the IP address is something other than localhost, then
402 400 consoles on other machines will be able to connect
403 401 to the Kernel, so be careful!"""
404 402 )
405 403
406 404 def _ip_default(self):
407 405 if self.transport == 'ipc':
408 406 if self.connection_file:
409 407 return os.path.splitext(self.connection_file)[0] + '-ipc'
410 408 else:
411 409 return 'kernel-ipc'
412 410 else:
413 411 return localhost()
414 412
415 413 def _ip_changed(self, name, old, new):
416 414 if new == '*':
417 415 self.ip = '0.0.0.0'
418 416
419 417 # protected traits
420 418
421 419 hb_port = Integer(0, config=True,
422 420 help="set the heartbeat port [default: random]")
423 421 shell_port = Integer(0, config=True,
424 422 help="set the shell (ROUTER) port [default: random]")
425 423 iopub_port = Integer(0, config=True,
426 424 help="set the iopub (PUB) port [default: random]")
427 425 stdin_port = Integer(0, config=True,
428 426 help="set the stdin (ROUTER) port [default: random]")
429 427 control_port = Integer(0, config=True,
430 428 help="set the control (ROUTER) port [default: random]")
431 429
432 430 @property
433 431 def ports(self):
434 432 return [ getattr(self, name) for name in port_names ]
435 433
436 434 # The Session to use for communication with the kernel.
437 435 session = Instance('IPython.kernel.zmq.session.Session')
438 436 def _session_default(self):
439 437 from IPython.kernel.zmq.session import Session
440 438 return Session(parent=self)
441 439
442 440 #--------------------------------------------------------------------------
443 441 # Connection and ipc file management
444 442 #--------------------------------------------------------------------------
445 443
446 444 def get_connection_info(self):
447 445 """return the connection info as a dict"""
448 446 return dict(
449 447 transport=self.transport,
450 448 ip=self.ip,
451 449 shell_port=self.shell_port,
452 450 iopub_port=self.iopub_port,
453 451 stdin_port=self.stdin_port,
454 452 hb_port=self.hb_port,
455 453 control_port=self.control_port,
456 454 signature_scheme=self.session.signature_scheme,
457 455 key=self.session.key,
458 456 )
459 457
460 458 def cleanup_connection_file(self):
461 459 """Cleanup connection file *if we wrote it*
462 460
463 461 Will not raise if the connection file was already removed somehow.
464 462 """
465 463 if self._connection_file_written:
466 464 # cleanup connection files on full shutdown of kernel we started
467 465 self._connection_file_written = False
468 466 try:
469 467 os.remove(self.connection_file)
470 468 except (IOError, OSError, AttributeError):
471 469 pass
472 470
473 471 def cleanup_ipc_files(self):
474 472 """Cleanup ipc files if we wrote them."""
475 473 if self.transport != 'ipc':
476 474 return
477 475 for port in self.ports:
478 476 ipcfile = "%s-%i" % (self.ip, port)
479 477 try:
480 478 os.remove(ipcfile)
481 479 except (IOError, OSError):
482 480 pass
483 481
484 482 def write_connection_file(self):
485 483 """Write connection info to JSON dict in self.connection_file."""
486 484 if self._connection_file_written and os.path.exists(self.connection_file):
487 485 return
488 486
489 487 self.connection_file, cfg = write_connection_file(self.connection_file,
490 488 transport=self.transport, ip=self.ip, key=self.session.key,
491 489 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
492 490 shell_port=self.shell_port, hb_port=self.hb_port,
493 491 control_port=self.control_port,
494 492 signature_scheme=self.session.signature_scheme,
495 493 )
496 494 # write_connection_file also sets default ports:
497 495 for name in port_names:
498 496 setattr(self, name, cfg[name])
499 497
500 498 self._connection_file_written = True
501 499
502 500 def load_connection_file(self):
503 501 """Load connection info from JSON dict in self.connection_file."""
504 502 self.log.debug(u"Loading connection file %s", self.connection_file)
505 503 with open(self.connection_file) as f:
506 504 cfg = json.load(f)
507 505 self.transport = cfg.get('transport', self.transport)
508 506 self.ip = cfg.get('ip', self._ip_default())
509 507
510 508 for name in port_names:
511 509 if getattr(self, name) == 0 and name in cfg:
512 510 # not overridden by config or cl_args
513 511 setattr(self, name, cfg[name])
514 512
515 513 if 'key' in cfg:
516 514 self.session.key = str_to_bytes(cfg['key'])
517 515 if 'signature_scheme' in cfg:
518 516 self.session.signature_scheme = cfg['signature_scheme']
519 517
520 518 #--------------------------------------------------------------------------
521 519 # Creating connected sockets
522 520 #--------------------------------------------------------------------------
523 521
524 522 def _make_url(self, channel):
525 523 """Make a ZeroMQ URL for a given channel."""
526 524 transport = self.transport
527 525 ip = self.ip
528 526 port = getattr(self, '%s_port' % channel)
529 527
530 528 if transport == 'tcp':
531 529 return "tcp://%s:%i" % (ip, port)
532 530 else:
533 531 return "%s://%s-%s" % (transport, ip, port)
534 532
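# The two URL shapes produced by _make_url, with illustrative values:
#   tcp: ('tcp', '127.0.0.1', 54321) -> 'tcp://127.0.0.1:54321'
#   ipc: ('ipc', 'kernel-ipc', 1)    -> 'ipc://kernel-ipc-1'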
535 533 def _create_connected_socket(self, channel, identity=None):
536 534 """Create a zmq Socket and connect it to the kernel."""
537 535 url = self._make_url(channel)
538 536 socket_type = channel_socket_types[channel]
539 537 self.log.debug("Connecting to: %s" % url)
540 538 sock = self.context.socket(socket_type)
541 539 # set linger to 1s to prevent hangs at exit
542 540 sock.linger = 1000
543 541 if identity:
544 542 sock.identity = identity
545 543 sock.connect(url)
546 544 return sock
547 545
548 546 def connect_iopub(self, identity=None):
549 547 """return zmq Socket connected to the IOPub channel"""
550 548 sock = self._create_connected_socket('iopub', identity=identity)
551 549 sock.setsockopt(zmq.SUBSCRIBE, b'')
552 550 return sock
553 551
554 552 def connect_shell(self, identity=None):
555 553 """return zmq Socket connected to the Shell channel"""
556 554 return self._create_connected_socket('shell', identity=identity)
557 555
558 556 def connect_stdin(self, identity=None):
559 557 """return zmq Socket connected to the StdIn channel"""
560 558 return self._create_connected_socket('stdin', identity=identity)
561 559
562 560 def connect_hb(self, identity=None):
563 561 """return zmq Socket connected to the Heartbeat channel"""
564 562 return self._create_connected_socket('hb', identity=identity)
565 563
566 564 def connect_control(self, identity=None):
567 565 """return zmq Socket connected to the Control channel"""
568 566 return self._create_connected_socket('control', identity=identity)
569 567
570 568
571 569 __all__ = [
572 570 'write_connection_file',
573 571 'get_connection_file',
574 572 'find_connection_file',
575 573 'get_connection_info',
576 574 'connect_qtconsole',
577 575 'tunnel_to_kernel',
578 576 ]
@@ -1,1868 +1,1868 @@
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 from zmq.ssh import tunnel
21 22
22 23 from IPython.config.configurable import MultipleInstanceError
23 24 from IPython.core.application import BaseIPythonApplication
24 25 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 26
26 27 from IPython.utils.capture import RichOutput
27 28 from IPython.utils.coloransi import TermColors
28 29 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 30 from IPython.utils.localinterfaces import localhost, is_local_ip
30 31 from IPython.utils.path import get_ipython_dir
31 32 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 33 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 34 Dict, List, Bool, Set, Any)
34 35 from IPython.external.decorator import decorator
35 from IPython.external.ssh import tunnel
36 36
37 37 from IPython.parallel import Reference
38 38 from IPython.parallel import error
39 39 from IPython.parallel import util
40 40
41 41 from IPython.kernel.zmq.session import Session, Message
42 42 from IPython.kernel.zmq import serialize
43 43
44 44 from .asyncresult import AsyncResult, AsyncHubResult
45 45 from .view import DirectView, LoadBalancedView
46 46
47 47 #--------------------------------------------------------------------------
48 48 # Decorators for Client methods
49 49 #--------------------------------------------------------------------------
50 50
51 51 @decorator
52 52 def spin_first(f, self, *args, **kwargs):
53 53 """Call spin() to sync state prior to calling the method."""
54 54 self.spin()
55 55 return f(self, *args, **kwargs)
56 56
57 57
58 58 #--------------------------------------------------------------------------
59 59 # Classes
60 60 #--------------------------------------------------------------------------
61 61
62 62
63 63 class ExecuteReply(RichOutput):
64 64 """wrapper for finished Execute results"""
65 65 def __init__(self, msg_id, content, metadata):
66 66 self.msg_id = msg_id
67 67 self._content = content
68 68 self.execution_count = content['execution_count']
69 69 self.metadata = metadata
70 70
71 71 # RichOutput overrides
72 72
73 73 @property
74 74 def source(self):
75 75 execute_result = self.metadata['execute_result']
76 76 if execute_result:
77 77 return execute_result.get('source', '')
78 78
79 79 @property
80 80 def data(self):
81 81 execute_result = self.metadata['execute_result']
82 82 if execute_result:
83 83 return execute_result.get('data', {})
84 84
85 85 @property
86 86 def _metadata(self):
87 87 execute_result = self.metadata['execute_result']
88 88 if execute_result:
89 89 return execute_result.get('metadata', {})
90 90
91 91 def display(self):
92 92 from IPython.display import publish_display_data
93 93 publish_display_data(self.data, self.metadata)
94 94
95 95 def _repr_mime_(self, mime):
96 96 if mime not in self.data:
97 97 return
98 98 data = self.data[mime]
99 99 if mime in self._metadata:
100 100 return data, self._metadata[mime]
101 101 else:
102 102 return data
103 103
104 104 def __getitem__(self, key):
105 105 return self.metadata[key]
106 106
107 107 def __getattr__(self, key):
108 108 if key not in self.metadata:
109 109 raise AttributeError(key)
110 110 return self.metadata[key]
111 111
112 112 def __repr__(self):
113 113 execute_result = self.metadata['execute_result'] or {'data':{}}
114 114 text_out = execute_result['data'].get('text/plain', '')
115 115 if len(text_out) > 32:
116 116 text_out = text_out[:29] + '...'
117 117
118 118 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
119 119
120 120 def _repr_pretty_(self, p, cycle):
121 121 execute_result = self.metadata['execute_result'] or {'data':{}}
122 122 text_out = execute_result['data'].get('text/plain', '')
123 123
124 124 if not text_out:
125 125 return
126 126
127 127 try:
128 128 ip = get_ipython()
129 129 except NameError:
130 130 colors = "NoColor"
131 131 else:
132 132 colors = ip.colors
133 133
134 134 if colors == "NoColor":
135 135 out = normal = ""
136 136 else:
137 137 out = TermColors.Red
138 138 normal = TermColors.Normal
139 139
140 140 if '\n' in text_out and not text_out.startswith('\n'):
141 141 # add newline for multiline reprs
142 142 text_out = '\n' + text_out
143 143
144 144 p.text(
145 145 out + u'Out[%i:%i]: ' % (
146 146 self.metadata['engine_id'], self.execution_count
147 147 ) + normal + text_out
148 148 )
149 149
150 150
151 151 class Metadata(dict):
152 152 """Subclass of dict for initializing metadata values.
153 153
154 154 Attribute access works on keys.
155 155
156 156 These objects have a strict set of keys - errors will raise if you try
157 157 to add new keys.
158 158 """
159 159 def __init__(self, *args, **kwargs):
160 160 dict.__init__(self)
161 161 md = {'msg_id' : None,
162 162 'submitted' : None,
163 163 'started' : None,
164 164 'completed' : None,
165 165 'received' : None,
166 166 'engine_uuid' : None,
167 167 'engine_id' : None,
168 168 'follow' : None,
169 169 'after' : None,
170 170 'status' : None,
171 171
172 172 'execute_input' : None,
173 173 'execute_result' : None,
174 174 'error' : None,
175 175 'stdout' : '',
176 176 'stderr' : '',
177 177 'outputs' : [],
178 178 'data': {},
179 179 'outputs_ready' : False,
180 180 }
181 181 self.update(md)
182 182 self.update(dict(*args, **kwargs))
183 183
184 184 def __getattr__(self, key):
185 185 """getattr aliased to getitem"""
186 186 if key in self:
187 187 return self[key]
188 188 else:
189 189 raise AttributeError(key)
190 190
191 191 def __setattr__(self, key, value):
192 192 """setattr aliased to setitem, with strict"""
193 193 if key in self:
194 194 self[key] = value
195 195 else:
196 196 raise AttributeError(key)
197 197
198 198 def __setitem__(self, key, value):
199 199 """strict static key enforcement"""
200 200 if key in self:
201 201 dict.__setitem__(self, key, value)
202 202 else:
203 203 raise KeyError(key)
204 204
205 205
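# Behavior sketch: Metadata is a dict with a fixed key set and attribute
# access.
#
#     md = Metadata(status='ok')
#     md.engine_id = 0        # attribute access aliases item access
#     md['no_such_key'] = 1   # raises KeyError: unknown keys are rejected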
206 206 class Client(HasTraits):
207 207 """A semi-synchronous client to the IPython ZMQ cluster
208 208
209 209 Parameters
210 210 ----------
211 211
212 212 url_file : str/unicode; path to ipcontroller-client.json
213 213 This JSON file should contain all the information needed to connect to a cluster,
214 214 and is likely the only argument needed.
215 215 Connection information for the Hub's registration. If a json connector
216 216 file is given, then likely no further configuration is necessary.
217 217 [Default: use profile]
218 218 profile : bytes
219 219 The name of the Cluster profile to be used to find connector information.
220 220 If run from an IPython application, the default profile will be the same
221 221 as the running application, otherwise it will be 'default'.
222 222 cluster_id : str
223 223 String id added to runtime files, to prevent name collisions when using
224 224 multiple clusters with a single profile simultaneously.
225 225 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
226 226 Since this is text inserted into filenames, typical recommendations apply:
227 227 Simple character strings are ideal, and spaces are not recommended (but
228 228 should generally work)
229 229 context : zmq.Context
230 230 Pass an existing zmq.Context instance, otherwise the client will create its own.
231 231 debug : bool
232 232 flag for lots of message printing for debug purposes
233 233 timeout : int/float
234 234 time (in seconds) to wait for connection replies from the Hub
235 235 [Default: 10]
236 236
237 237 #-------------- session related args ----------------
238 238
239 239 config : Config object
240 240 If specified, this will be relayed to the Session for configuration
241 241 username : str
242 242 set username for the session object
243 243
244 244 #-------------- ssh related args ----------------
245 245 # These are args for configuring the ssh tunnel to be used
246 246 # credentials are used to forward connections over ssh to the Controller
247 247 # Note that the ip given in `addr` needs to be relative to sshserver
248 248 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
249 249 # and set sshserver as the same machine the Controller is on. However,
250 250 # the only requirement is that sshserver is able to see the Controller
251 251 # (i.e. is within the same trusted network).
252 252
253 253 sshserver : str
254 254 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
255 255 If keyfile or password is specified, and this is not, it will default to
256 256 the ip given in addr.
257 257 sshkey : str; path to ssh private key file
258 258 This specifies a key to be used in ssh login, default None.
259 259 Regular default ssh keys will be used without specifying this argument.
260 260 password : str
261 261 Your ssh password to sshserver. Note that if this is left None,
262 262 you will be prompted for it if passwordless key based login is unavailable.
263 263 paramiko : bool
264 264 flag for whether to use paramiko instead of shell ssh for tunneling.
265 265 [default: True on win32, False otherwise]
266 266
267 267
268 268 Attributes
269 269 ----------
270 270
271 271 ids : list of int engine IDs
272 272 requesting the ids attribute always synchronizes
273 273 the registration state. To request ids without synchronization,
274 274 use the semi-private _ids attribute.
275 275
276 276 history : list of msg_ids
277 277 a list of msg_ids, keeping track of all the execution
278 278 messages you have submitted in order.
279 279
280 280 outstanding : set of msg_ids
281 281 a set of msg_ids that have been submitted, but whose
282 282 results have not yet been received.
283 283
284 284 results : dict
285 285 a dict of all our results, keyed by msg_id
286 286
287 287 block : bool
288 288 determines default behavior when block not specified
289 289 in execution methods
290 290
291 291 Methods
292 292 -------
293 293
294 294 spin
295 295 flushes incoming results and registration state changes
296 296 control methods spin, and requesting `ids` also ensures up to date
297 297
298 298 wait
299 299 wait on one or more msg_ids
300 300
301 301 execution methods
302 302 apply
303 303 legacy: execute, run
304 304
305 305 data movement
306 306 push, pull, scatter, gather
307 307
308 308 query methods
309 309 queue_status, get_result, purge, result_status
310 310
311 311 control methods
312 312 abort, shutdown
313 313
314 314 """
315 315
316 316
317 317 block = Bool(False)
318 318 outstanding = Set()
319 319 results = Instance('collections.defaultdict', (dict,))
320 320 metadata = Instance('collections.defaultdict', (Metadata,))
321 321 history = List()
322 322 debug = Bool(False)
323 323 _spin_thread = Any()
324 324 _stop_spinning = Any()
325 325
326 326 profile=Unicode()
327 327 def _profile_default(self):
328 328 if BaseIPythonApplication.initialized():
329 329 # an IPython app *might* be running, try to get its profile
330 330 try:
331 331 return BaseIPythonApplication.instance().profile
332 332 except (AttributeError, MultipleInstanceError):
333 333 # could be a *different* subclass of config.Application,
334 334 # which would raise one of these two errors.
335 335 return u'default'
336 336 else:
337 337 return u'default'
338 338
339 339
340 340 _outstanding_dict = Instance('collections.defaultdict', (set,))
341 341 _ids = List()
342 342 _connected=Bool(False)
343 343 _ssh=Bool(False)
344 344 _context = Instance('zmq.Context')
345 345 _config = Dict()
346 346 _engines=Instance(util.ReverseDict, (), {})
347 347 # _hub_socket=Instance('zmq.Socket')
348 348 _query_socket=Instance('zmq.Socket')
349 349 _control_socket=Instance('zmq.Socket')
350 350 _iopub_socket=Instance('zmq.Socket')
351 351 _notification_socket=Instance('zmq.Socket')
352 352 _mux_socket=Instance('zmq.Socket')
353 353 _task_socket=Instance('zmq.Socket')
354 354 _task_scheme=Unicode()
355 355 _closed = False
356 356 _ignored_control_replies=Integer(0)
357 357 _ignored_hub_replies=Integer(0)
358 358
359 359 def __new__(self, *args, **kw):
360 360 # don't raise on positional args
361 361 return HasTraits.__new__(self, **kw)
362 362
363 363 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
364 364 context=None, debug=False,
365 365 sshserver=None, sshkey=None, password=None, paramiko=None,
366 366 timeout=10, cluster_id=None, **extra_args
367 367 ):
368 368 if profile:
369 369 super(Client, self).__init__(debug=debug, profile=profile)
370 370 else:
371 371 super(Client, self).__init__(debug=debug)
372 372 if context is None:
373 373 context = zmq.Context.instance()
374 374 self._context = context
375 375 self._stop_spinning = Event()
376 376
377 377 if 'url_or_file' in extra_args:
378 378 url_file = extra_args['url_or_file']
379 379 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
380 380
381 381 if url_file and util.is_url(url_file):
382 382 raise ValueError("single urls cannot be specified, url-files must be used.")
383 383
384 384 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
385 385
386 386 if self._cd is not None:
387 387 if url_file is None:
388 388 if not cluster_id:
389 389 client_json = 'ipcontroller-client.json'
390 390 else:
391 391 client_json = 'ipcontroller-%s-client.json' % cluster_id
392 392 url_file = pjoin(self._cd.security_dir, client_json)
393 393 if url_file is None:
394 394 raise ValueError(
395 395 "I can't find enough information to connect to a hub!"
396 396 " Please specify at least one of url_file or profile."
397 397 )
398 398
399 399 with open(url_file) as f:
400 400 cfg = json.load(f)
401 401
402 402 self._task_scheme = cfg['task_scheme']
403 403
404 404 # sync defaults from args, json:
405 405 if sshserver:
406 406 cfg['ssh'] = sshserver
407 407
408 408 location = cfg.setdefault('location', None)
409 409
410 410 proto,addr = cfg['interface'].split('://')
411 411 addr = util.disambiguate_ip_address(addr, location)
412 412 cfg['interface'] = "%s://%s" % (proto, addr)
413 413
414 414 # turn interface,port into full urls:
415 415 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
416 416 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
417 417
418 418 url = cfg['registration']
419 419
420 420 if location is not None and addr == localhost():
421 421 # location specified, and connection is expected to be local
422 422 if not is_local_ip(location) and not sshserver:
423 423 # load ssh from JSON *only* if the controller is not on
424 424 # this machine
425 425 sshserver=cfg['ssh']
426 426 if not is_local_ip(location) and not sshserver:
427 427 # warn if no ssh specified, but SSH is probably needed
428 428 # This is only a warning, because the most likely cause
429 429 # is a local Controller on a laptop whose IP is dynamic
430 430 warnings.warn("""
431 431 Controller appears to be listening on localhost, but not on this machine.
432 432 If this is true, you should specify Client(...,sshserver='you@%s')
433 433 or instruct your controller to listen on an external IP."""%location,
434 434 RuntimeWarning)
435 435 elif not sshserver:
436 436 # otherwise sync with cfg
437 437 sshserver = cfg['ssh']
438 438
439 439 self._config = cfg
440 440
441 441 self._ssh = bool(sshserver or sshkey or password)
442 442 if self._ssh and sshserver is None:
443 443 # default to ssh via localhost
444 444 sshserver = addr
445 445 if self._ssh and password is None:
446 446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
447 447 password=False
448 448 else:
449 449 password = getpass("SSH Password for %s: "%sshserver)
450 450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
451 451
452 452 # configure and construct the session
453 453 try:
454 454 extra_args['packer'] = cfg['pack']
455 455 extra_args['unpacker'] = cfg['unpack']
456 456 extra_args['key'] = cast_bytes(cfg['key'])
457 457 extra_args['signature_scheme'] = cfg['signature_scheme']
458 458 except KeyError as exc:
459 459 msg = '\n'.join([
460 460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
461 461 "If you are reusing connection files, remove them and start ipcontroller again."
462 462 ])
463 463 raise ValueError(msg.format(exc.args[0]))
464 464
465 465 self.session = Session(**extra_args)
466 466
467 467 self._query_socket = self._context.socket(zmq.DEALER)
468 468
469 469 if self._ssh:
470 470 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
471 471 else:
472 472 self._query_socket.connect(cfg['registration'])
473 473
474 474 self.session.debug = self.debug
475 475
476 476 self._notification_handlers = {'registration_notification' : self._register_engine,
477 477 'unregistration_notification' : self._unregister_engine,
478 478 'shutdown_notification' : lambda msg: self.close(),
479 479 }
480 480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
481 481 'apply_reply' : self._handle_apply_reply}
482 482
483 483 try:
484 484 self._connect(sshserver, ssh_kwargs, timeout)
485 485 except:
486 486 self.close(linger=0)
487 487 raise
488 488
489 489 # last step: setup magics, if we are in IPython:
490 490
491 491 try:
492 492 ip = get_ipython()
493 493 except NameError:
494 494 return
495 495 else:
496 496 if 'px' not in ip.magics_manager.magics:
497 497 # in IPython but we are the first Client.
498 498 # activate a default view for parallel magics.
499 499 self.activate()
500 500
501 501 def __del__(self):
502 502 """cleanup sockets, but _not_ context."""
503 503 self.close()
504 504
505 505 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
506 506 if ipython_dir is None:
507 507 ipython_dir = get_ipython_dir()
508 508 if profile_dir is not None:
509 509 try:
510 510 self._cd = ProfileDir.find_profile_dir(profile_dir)
511 511 return
512 512 except ProfileDirError:
513 513 pass
514 514 elif profile is not None:
515 515 try:
516 516 self._cd = ProfileDir.find_profile_dir_by_name(
517 517 ipython_dir, profile)
518 518 return
519 519 except ProfileDirError:
520 520 pass
521 521 self._cd = None
522 522
523 523 def _update_engines(self, engines):
524 524 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
525 525 for k,v in iteritems(engines):
526 526 eid = int(k)
527 527 if eid not in self._engines:
528 528 self._ids.append(eid)
529 529 self._engines[eid] = v
530 530 self._ids = sorted(self._ids)
531 531 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
532 532 self._task_scheme == 'pure' and self._task_socket:
533 533 self._stop_scheduling_tasks()
534 534
535 535 def _stop_scheduling_tasks(self):
536 536 """Stop scheduling tasks because an engine has been unregistered
537 537 from a pure ZMQ scheduler.
538 538 """
539 539 self._task_socket.close()
540 540 self._task_socket = None
541 541 msg = "An engine has been unregistered, and we are using pure " +\
542 542 "ZMQ task scheduling. Task farming will be disabled."
543 543 if self.outstanding:
544 544 msg += " If you were running tasks when this happened, " +\
545 545 "some `outstanding` msg_ids may never resolve."
546 546 warnings.warn(msg, RuntimeWarning)
547 547
548 548 def _build_targets(self, targets):
549 549 """Turn valid target IDs or 'all' into two lists:
550 550 (int_ids, uuids).
551 551 """
552 552 if not self._ids:
553 553 # flush notification socket if no engines yet, just in case
554 554 if not self.ids:
555 555 raise error.NoEnginesRegistered("Can't build targets without any engines")
556 556
557 557 if targets is None:
558 558 targets = self._ids
559 559 elif isinstance(targets, string_types):
560 560 if targets.lower() == 'all':
561 561 targets = self._ids
562 562 else:
563 563 raise TypeError("%r not valid str target, must be 'all'"%(targets))
564 564 elif isinstance(targets, int):
565 565 if targets < 0:
566 566 targets = self.ids[targets]
567 567 if targets not in self._ids:
568 568 raise IndexError("No such engine: %i"%targets)
569 569 targets = [targets]
570 570
571 571 if isinstance(targets, slice):
572 572 indices = list(range(len(self._ids))[targets])
573 573 ids = self.ids
574 574 targets = [ ids[i] for i in indices ]
575 575
576 576 if not isinstance(targets, (tuple, list, xrange)):
577 577 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
578 578
579 579 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
580 580
581 581 def _connect(self, sshserver, ssh_kwargs, timeout):
582 582 """setup all our socket connections to the cluster. This is called from
583 583 __init__."""
584 584
585 585 # Maybe allow reconnecting?
586 586 if self._connected:
587 587 return
588 588 self._connected=True
589 589
590 590 def connect_socket(s, url):
591 591 if self._ssh:
592 592 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
593 593 else:
594 594 return s.connect(url)
595 595
596 596 self.session.send(self._query_socket, 'connection_request')
597 597 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
598 598 poller = zmq.Poller()
599 599 poller.register(self._query_socket, zmq.POLLIN)
600 600 # poll expects milliseconds, timeout is seconds
601 601 evts = poller.poll(timeout*1000)
602 602 if not evts:
603 603 raise error.TimeoutError("Hub connection request timed out")
604 604 idents,msg = self.session.recv(self._query_socket,mode=0)
605 605 if self.debug:
606 606 pprint(msg)
607 607 content = msg['content']
608 608 # self._config['registration'] = dict(content)
609 609 cfg = self._config
610 610 if content['status'] == 'ok':
611 611 self._mux_socket = self._context.socket(zmq.DEALER)
612 612 connect_socket(self._mux_socket, cfg['mux'])
613 613
614 614 self._task_socket = self._context.socket(zmq.DEALER)
615 615 connect_socket(self._task_socket, cfg['task'])
616 616
617 617 self._notification_socket = self._context.socket(zmq.SUB)
618 618 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
619 619 connect_socket(self._notification_socket, cfg['notification'])
620 620
621 621 self._control_socket = self._context.socket(zmq.DEALER)
622 622 connect_socket(self._control_socket, cfg['control'])
623 623
624 624 self._iopub_socket = self._context.socket(zmq.SUB)
625 625 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
626 626 connect_socket(self._iopub_socket, cfg['iopub'])
627 627
628 628 self._update_engines(dict(content['engines']))
629 629 else:
630 630 self._connected = False
631 631 raise Exception("Failed to connect!")
632 632
633 633 #--------------------------------------------------------------------------
634 634 # handlers and callbacks for incoming messages
635 635 #--------------------------------------------------------------------------
636 636
637 637 def _unwrap_exception(self, content):
638 638 """unwrap exception, and remap engine_id to int."""
639 639 e = error.unwrap_exception(content)
640 640 # print e.traceback
641 641 if e.engine_info:
642 642 e_uuid = e.engine_info['engine_uuid']
643 643 eid = self._engines[e_uuid]
644 644 e.engine_info['engine_id'] = eid
645 645 return e
646 646
647 647 def _extract_metadata(self, msg):
648 648 header = msg['header']
649 649 parent = msg['parent_header']
650 650 msg_meta = msg['metadata']
651 651 content = msg['content']
652 652 md = {'msg_id' : parent['msg_id'],
653 653 'received' : datetime.now(),
654 654 'engine_uuid' : msg_meta.get('engine', None),
655 655 'follow' : msg_meta.get('follow', []),
656 656 'after' : msg_meta.get('after', []),
657 657 'status' : content['status'],
658 658 }
659 659
660 660 if md['engine_uuid'] is not None:
661 661 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
662 662
663 663 if 'date' in parent:
664 664 md['submitted'] = parent['date']
665 665 if 'started' in msg_meta:
666 666 md['started'] = parse_date(msg_meta['started'])
667 667 if 'date' in header:
668 668 md['completed'] = header['date']
669 669 return md
670 670
671 671 def _register_engine(self, msg):
672 672 """Register a new engine, and update our connection info."""
673 673 content = msg['content']
674 674 eid = content['id']
675 675 d = {eid : content['uuid']}
676 676 self._update_engines(d)
677 677
678 678 def _unregister_engine(self, msg):
679 679 """Unregister an engine that has died."""
680 680 content = msg['content']
681 681 eid = int(content['id'])
682 682 if eid in self._ids:
683 683 self._ids.remove(eid)
684 684 uuid = self._engines.pop(eid)
685 685
686 686 self._handle_stranded_msgs(eid, uuid)
687 687
688 688 if self._task_socket and self._task_scheme == 'pure':
689 689 self._stop_scheduling_tasks()
690 690
691 691 def _handle_stranded_msgs(self, eid, uuid):
692 692 """Handle messages known to be on an engine when the engine unregisters.
693 693
694 694 It is possible that this will fire prematurely - that is, an engine will
695 695 go down after completing a result, and the client will be notified
696 696 of the unregistration and later receive the successful result.
697 697 """
698 698
699 699 outstanding = self._outstanding_dict[uuid]
700 700
701 701 for msg_id in list(outstanding):
702 702 if msg_id in self.results:
703 703 # we already have the result
704 704 continue
705 705 try:
706 706 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
707 707 except:
708 708 content = error.wrap_exception()
709 709 # build a fake message:
710 710 msg = self.session.msg('apply_reply', content=content)
711 711 msg['parent_header']['msg_id'] = msg_id
712 712 msg['metadata']['engine'] = uuid
713 713 self._handle_apply_reply(msg)
714 714
715 715 def _handle_execute_reply(self, msg):
716 716 """Save the reply to an execute_request into our results.
717 717
718 718 execute messages are never actually used. apply is used instead.
719 719 """
720 720
721 721 parent = msg['parent_header']
722 722 msg_id = parent['msg_id']
723 723 if msg_id not in self.outstanding:
724 724 if msg_id in self.history:
725 725 print("got stale result: %s"%msg_id)
726 726 else:
727 727 print("got unknown result: %s"%msg_id)
728 728 else:
729 729 self.outstanding.remove(msg_id)
730 730
731 731 content = msg['content']
732 732 header = msg['header']
733 733
734 734 # construct metadata:
735 735 md = self.metadata[msg_id]
736 736 md.update(self._extract_metadata(msg))
737 737 # is this redundant?
738 738 self.metadata[msg_id] = md
739 739
740 740 e_outstanding = self._outstanding_dict[md['engine_uuid']]
741 741 if msg_id in e_outstanding:
742 742 e_outstanding.remove(msg_id)
743 743
744 744 # construct result:
745 745 if content['status'] == 'ok':
746 746 self.results[msg_id] = ExecuteReply(msg_id, content, md)
747 747 elif content['status'] == 'aborted':
748 748 self.results[msg_id] = error.TaskAborted(msg_id)
749 749 elif content['status'] == 'resubmitted':
750 750 # TODO: handle resubmission
751 751 pass
752 752 else:
753 753 self.results[msg_id] = self._unwrap_exception(content)
754 754
755 755 def _handle_apply_reply(self, msg):
756 756 """Save the reply to an apply_request into our results."""
757 757 parent = msg['parent_header']
758 758 msg_id = parent['msg_id']
759 759 if msg_id not in self.outstanding:
760 760 if msg_id in self.history:
761 761 print("got stale result: %s"%msg_id)
762 762 print(self.results[msg_id])
763 763 print(msg)
764 764 else:
765 765 print("got unknown result: %s"%msg_id)
766 766 else:
767 767 self.outstanding.remove(msg_id)
768 768 content = msg['content']
769 769 header = msg['header']
770 770
771 771 # construct metadata:
772 772 md = self.metadata[msg_id]
773 773 md.update(self._extract_metadata(msg))
774 774 # is this redundant?
775 775 self.metadata[msg_id] = md
776 776
777 777 e_outstanding = self._outstanding_dict[md['engine_uuid']]
778 778 if msg_id in e_outstanding:
779 779 e_outstanding.remove(msg_id)
780 780
781 781 # construct result:
782 782 if content['status'] == 'ok':
783 783 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
784 784 elif content['status'] == 'aborted':
785 785 self.results[msg_id] = error.TaskAborted(msg_id)
786 786 elif content['status'] == 'resubmitted':
787 787 # TODO: handle resubmission
788 788 pass
789 789 else:
790 790 self.results[msg_id] = self._unwrap_exception(content)
791 791
792 792 def _flush_notifications(self):
793 793 """Flush notifications of engine registrations waiting
794 794 in ZMQ queue."""
795 795 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
796 796 while msg is not None:
797 797 if self.debug:
798 798 pprint(msg)
799 799 msg_type = msg['header']['msg_type']
800 800 handler = self._notification_handlers.get(msg_type, None)
801 801 if handler is None:
802 802 raise Exception("Unhandled message type: %s" % msg_type)
803 803 else:
804 804 handler(msg)
805 805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
806 806
807 807 def _flush_results(self, sock):
808 808 """Flush task or queue results waiting in ZMQ queue."""
809 809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
810 810 while msg is not None:
811 811 if self.debug:
812 812 pprint(msg)
813 813 msg_type = msg['header']['msg_type']
814 814 handler = self._queue_handlers.get(msg_type, None)
815 815 if handler is None:
816 816 raise Exception("Unhandled message type: %s" % msg_type)
817 817 else:
818 818 handler(msg)
819 819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
820 820
821 821 def _flush_control(self, sock):
822 822 """Flush replies from the control channel waiting
823 823 in the ZMQ queue.
824 824
825 825 Currently: ignore them."""
826 826 if self._ignored_control_replies <= 0:
827 827 return
828 828 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
829 829 while msg is not None:
830 830 self._ignored_control_replies -= 1
831 831 if self.debug:
832 832 pprint(msg)
833 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 834
835 835 def _flush_ignored_control(self):
836 836 """flush ignored control replies"""
837 837 while self._ignored_control_replies > 0:
838 838 self.session.recv(self._control_socket)
839 839 self._ignored_control_replies -= 1
840 840
841 841 def _flush_ignored_hub_replies(self):
842 842 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843 843 while msg is not None:
844 844 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
845 845
846 846 def _flush_iopub(self, sock):
847 847 """Flush replies from the iopub channel waiting
848 848 in the ZMQ queue.
849 849 """
850 850 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
851 851 while msg is not None:
852 852 if self.debug:
853 853 pprint(msg)
854 854 parent = msg['parent_header']
855 855 # ignore IOPub messages with no parent.
856 856 # Caused by print statements or warnings from before the first execution.
857 857 if not parent:
858 858 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
859 859 continue
860 860 msg_id = parent['msg_id']
861 861 content = msg['content']
862 862 header = msg['header']
863 863 msg_type = msg['header']['msg_type']
864 864
865 865 # init metadata:
866 866 md = self.metadata[msg_id]
867 867
868 868 if msg_type == 'stream':
869 869 name = content['name']
870 870 s = md[name] or ''
871 871 md[name] = s + content['data']
872 872 elif msg_type == 'error':
873 873 md.update({'error' : self._unwrap_exception(content)})
874 874 elif msg_type == 'execute_input':
875 875 md.update({'execute_input' : content['code']})
876 876 elif msg_type == 'display_data':
877 877 md['outputs'].append(content)
878 878 elif msg_type == 'execute_result':
879 879 md['execute_result'] = content
880 880 elif msg_type == 'data_message':
881 881 data, remainder = serialize.unserialize_object(msg['buffers'])
882 882 md['data'].update(data)
883 883 elif msg_type == 'status':
884 884 # idle message comes after all outputs
885 885 if content['execution_state'] == 'idle':
886 886 md['outputs_ready'] = True
887 887 else:
888 888 # unhandled msg_type (status, etc.)
889 889 pass
890 890
891 891 # redundant?
892 892 self.metadata[msg_id] = md
893 893
894 894 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
895 895
896 896 #--------------------------------------------------------------------------
897 897 # len, getitem
898 898 #--------------------------------------------------------------------------
899 899
900 900 def __len__(self):
901 901 """len(client) returns # of engines."""
902 902 return len(self.ids)
903 903
904 904 def __getitem__(self, key):
905 905 """index access returns DirectView multiplexer objects
906 906
907 907 Must be int, slice, or list/tuple/xrange of ints"""
908 908 if not isinstance(key, (int, slice, tuple, list, xrange)):
909 909 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
910 910 else:
911 911 return self.direct_view(key)
912 912
913 913 def __iter__(self):
914 914 """Since we define __getitem__, Client is iterable
915 915
916 916 but without also defining __iter__, iteration is only correct when engine IDs
917 917 start at zero and are contiguous.
918 918 """
919 919 for eid in self.ids:
920 920 yield self.direct_view(eid)
921 921
922 922 #--------------------------------------------------------------------------
923 923 # Begin public methods
924 924 #--------------------------------------------------------------------------
925 925
926 926 @property
927 927 def ids(self):
928 928 """Always up-to-date ids property."""
929 929 self._flush_notifications()
930 930 # always copy:
931 931 return list(self._ids)
932 932
933 933 def activate(self, targets='all', suffix=''):
934 934 """Create a DirectView and register it with IPython magics
935 935
936 936 Defines the magics `%px, %autopx, %pxresult, %%px`
937 937
938 938 Parameters
939 939 ----------
940 940
941 941 targets: int, list of ints, or 'all'
942 942 The engines on which the view's magics will run
943 943 suffix: str [default: '']
944 944 The suffix, if any, for the magics. This allows you to have
945 945 multiple views associated with parallel magics at the same time.
946 946
947 947 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
948 948 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
949 949 on engine 0.
950 950 """
951 951 view = self.direct_view(targets)
952 952 view.block = True
953 953 view.activate(suffix)
954 954 return view
955 955
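# Usage sketch for activate (engine id and suffix are illustrative):
#
#     rc = Client()
#     rc.activate(targets=0, suffix='0')
#     # now %px0 / %%px0 / %pxresult0 run only on engine 0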
956 956 def close(self, linger=None):
957 957 """Close my zmq Sockets
958 958
959 959 If `linger`, set the zmq LINGER socket option,
960 960 which allows discarding of messages.
961 961 """
962 962 if self._closed:
963 963 return
964 964 self.stop_spin_thread()
965 965 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
966 966 for name in snames:
967 967 socket = getattr(self, name)
968 968 if socket is not None and not socket.closed:
969 969 if linger is not None:
970 970 socket.close(linger=linger)
971 971 else:
972 972 socket.close()
973 973 self._closed = True
974 974
975 975 def _spin_every(self, interval=1):
976 976 """target func for use in spin_thread"""
977 977 while True:
978 978 if self._stop_spinning.is_set():
979 979 return
980 980 time.sleep(interval)
981 981 self.spin()
982 982
983 983 def spin_thread(self, interval=1):
984 984 """call Client.spin() in a background thread on some regular interval
985 985
986 986 This helps ensure that messages don't pile up too much in the zmq queue
987 987 while you are working on other things, or just leaving an idle terminal.
988 988
989 989 It also helps limit the potential lag ('padding') in the `received` timestamp
990 990 on AsyncResult objects, which is used for timings.
991 991
992 992 Parameters
993 993 ----------
994 994
995 995 interval : float, optional
996 996 The interval on which to spin the client in the background thread
997 997 (simply passed to time.sleep).
998 998
999 999 Notes
1000 1000 -----
1001 1001
1002 1002 For precision timing, you may want to use this method to put a bound
1003 1003 on the jitter (in seconds) in `received` timestamps used
1004 1004 in AsyncResult.wall_time.
1005 1005
1006 1006 """
1007 1007 if self._spin_thread is not None:
1008 1008 self.stop_spin_thread()
1009 1009 self._stop_spinning.clear()
1010 1010 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1011 1011 self._spin_thread.daemon = True
1012 1012 self._spin_thread.start()
1013 1013
1014 1014 def stop_spin_thread(self):
1015 1015 """stop background spin_thread, if any"""
1016 1016 if self._spin_thread is not None:
1017 1017 self._stop_spinning.set()
1018 1018 self._spin_thread.join()
1019 1019 self._spin_thread = None
1020 1020
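A short sketch of the spin-thread API above; the 0.1 s interval is an illustrative choice:

```python
import time
from IPython.parallel import Client

rc = Client()
rc.spin_thread(interval=0.1)           # flush zmq queues every 100 ms

ar = rc[:].apply_async(time.sleep, 1)
ar.wait()
print(ar.wall_time)                    # 'received' jitter now bounded by ~0.1 s

rc.stop_spin_thread()
```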
1021 1021 def spin(self):
1022 1022 """Flush any registration notifications and execution results
1023 1023 waiting in the ZMQ queue.
1024 1024 """
1025 1025 if self._notification_socket:
1026 1026 self._flush_notifications()
1027 1027 if self._iopub_socket:
1028 1028 self._flush_iopub(self._iopub_socket)
1029 1029 if self._mux_socket:
1030 1030 self._flush_results(self._mux_socket)
1031 1031 if self._task_socket:
1032 1032 self._flush_results(self._task_socket)
1033 1033 if self._control_socket:
1034 1034 self._flush_control(self._control_socket)
1035 1035 if self._query_socket:
1036 1036 self._flush_ignored_hub_replies()
1037 1037
1038 1038 def wait(self, jobs=None, timeout=-1):
1039 1039 """waits on one or more `jobs`, for up to `timeout` seconds.
1040 1040
1041 1041 Parameters
1042 1042 ----------
1043 1043
1044 1044 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1045 1045 ints are indices to self.history
1046 1046 strs are msg_ids
1047 1047 default: wait on all outstanding messages
1048 1048 timeout : float
1049 1049 a time in seconds, after which to give up.
1050 1050 default is -1, which means no timeout
1051 1051
1052 1052 Returns
1053 1053 -------
1054 1054
1055 1055 True : when all msg_ids are done
1056 1056 False : timeout reached, some msg_ids still outstanding
1057 1057 """
1058 1058 tic = time.time()
1059 1059 if jobs is None:
1060 1060 theids = self.outstanding
1061 1061 else:
1062 1062 if isinstance(jobs, string_types + (int, AsyncResult)):
1063 1063 jobs = [jobs]
1064 1064 theids = set()
1065 1065 for job in jobs:
1066 1066 if isinstance(job, int):
1067 1067 # index access
1068 1068 job = self.history[job]
1069 1069 elif isinstance(job, AsyncResult):
1070 1070 theids.update(job.msg_ids)
1071 1071 continue
1072 1072 theids.add(job)
1073 1073 if not theids.intersection(self.outstanding):
1074 1074 return True
1075 1075 self.spin()
1076 1076 while theids.intersection(self.outstanding):
1077 1077 if timeout >= 0 and ( time.time()-tic ) > timeout:
1078 1078 break
1079 1079 time.sleep(1e-3)
1080 1080 self.spin()
1081 1081 return len(theids.intersection(self.outstanding)) == 0
1082 1082
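A sketch of `wait` with an explicit job list and with the wait-on-everything default:

```python
import time
from IPython.parallel import Client

rc = Client()
quick = rc[:].apply_async(time.sleep, 1)
slow = rc[:].apply_async(time.sleep, 10)

if rc.wait([quick], timeout=2):        # True: these msg_ids finished in time
    print(quick.get())

rc.wait()                              # no arguments: block on all outstanding jobs
```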
1083 1083 #--------------------------------------------------------------------------
1084 1084 # Control methods
1085 1085 #--------------------------------------------------------------------------
1086 1086
1087 1087 @spin_first
1088 1088 def clear(self, targets=None, block=None):
1089 1089 """Clear the namespace in target(s)."""
1090 1090 block = self.block if block is None else block
1091 1091 targets = self._build_targets(targets)[0]
1092 1092 for t in targets:
1093 1093 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1094 1094 error = False
1095 1095 if block:
1096 1096 self._flush_ignored_control()
1097 1097 for i in range(len(targets)):
1098 1098 idents,msg = self.session.recv(self._control_socket,0)
1099 1099 if self.debug:
1100 1100 pprint(msg)
1101 1101 if msg['content']['status'] != 'ok':
1102 1102 error = self._unwrap_exception(msg['content'])
1103 1103 else:
1104 1104 self._ignored_control_replies += len(targets)
1105 1105 if error:
1106 1106 raise error
1107 1107
1108 1108
1109 1109 @spin_first
1110 1110 def abort(self, jobs=None, targets=None, block=None):
1111 1111 """Abort specific jobs from the execution queues of target(s).
1112 1112
1113 1113 This is a mechanism to prevent jobs that have already been submitted
1114 1114 from executing.
1115 1115
1116 1116 Parameters
1117 1117 ----------
1118 1118
1119 1119 jobs : msg_id, list of msg_ids, or AsyncResult
1120 1120 The jobs to be aborted
1121 1121
1122 1122 If unspecified/None: abort all outstanding jobs.
1123 1123
1124 1124 """
1125 1125 block = self.block if block is None else block
1126 1126 jobs = jobs if jobs is not None else list(self.outstanding)
1127 1127 targets = self._build_targets(targets)[0]
1128 1128
1129 1129 msg_ids = []
1130 1130 if isinstance(jobs, string_types + (AsyncResult,)):
1131 1131 jobs = [jobs]
1132 1132 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1133 1133 if bad_ids:
1134 1134 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1135 1135 for j in jobs:
1136 1136 if isinstance(j, AsyncResult):
1137 1137 msg_ids.extend(j.msg_ids)
1138 1138 else:
1139 1139 msg_ids.append(j)
1140 1140 content = dict(msg_ids=msg_ids)
1141 1141 for t in targets:
1142 1142 self.session.send(self._control_socket, 'abort_request',
1143 1143 content=content, ident=t)
1144 1144 error = False
1145 1145 if block:
1146 1146 self._flush_ignored_control()
1147 1147 for i in range(len(targets)):
1148 1148 idents,msg = self.session.recv(self._control_socket,0)
1149 1149 if self.debug:
1150 1150 pprint(msg)
1151 1151 if msg['content']['status'] != 'ok':
1152 1152 error = self._unwrap_exception(msg['content'])
1153 1153 else:
1154 1154 self._ignored_control_replies += len(targets)
1155 1155 if error:
1156 1156 raise error
1157 1157
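A sketch combining the two control methods above; note that `abort` prevents queued jobs from running, it does not interrupt a job that has already started:

```python
import time
from IPython.parallel import Client

rc = Client()
ar = rc[:].apply_async(time.sleep, 60)   # queue some long-running work
rc.abort(ar)                             # drop it from the engine queues
rc.clear()                               # then wipe the engines' namespaces
```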
1158 1158 @spin_first
1159 1159 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1160 1160 """Terminates one or more engine processes, optionally including the hub.
1161 1161
1162 1162 Parameters
1163 1163 ----------
1164 1164
1165 1165 targets: list of ints or 'all' [default: all]
1166 1166 Which engines to shutdown.
1167 1167 hub: bool [default: False]
1168 1168 Whether to include the Hub. hub=True implies targets='all'.
1169 1169 block: bool [default: self.block]
1170 1170 Whether to wait for clean shutdown replies or not.
1171 1171 restart: bool [default: False]
1172 1172 NOT IMPLEMENTED
1173 1173 whether to restart engines after shutting them down.
1174 1174 """
1175 1175 from IPython.parallel.error import NoEnginesRegistered
1176 1176 if restart:
1177 1177 raise NotImplementedError("Engine restart is not yet implemented")
1178 1178
1179 1179 block = self.block if block is None else block
1180 1180 if hub:
1181 1181 targets = 'all'
1182 1182 try:
1183 1183 targets = self._build_targets(targets)[0]
1184 1184 except NoEnginesRegistered:
1185 1185 targets = []
1186 1186 for t in targets:
1187 1187 self.session.send(self._control_socket, 'shutdown_request',
1188 1188 content={'restart':restart},ident=t)
1189 1189 error = False
1190 1190 if block or hub:
1191 1191 self._flush_ignored_control()
1192 1192 for i in range(len(targets)):
1193 1193 idents,msg = self.session.recv(self._control_socket, 0)
1194 1194 if self.debug:
1195 1195 pprint(msg)
1196 1196 if msg['content']['status'] != 'ok':
1197 1197 error = self._unwrap_exception(msg['content'])
1198 1198 else:
1199 1199 self._ignored_control_replies += len(targets)
1200 1200
1201 1201 if hub:
1202 1202 time.sleep(0.25)
1203 1203 self.session.send(self._query_socket, 'shutdown_request')
1204 1204 idents,msg = self.session.recv(self._query_socket, 0)
1205 1205 if self.debug:
1206 1206 pprint(msg)
1207 1207 if msg['content']['status'] != 'ok':
1208 1208 error = self._unwrap_exception(msg['content'])
1209 1209
1210 1210 if error:
1211 1211 raise error
1212 1212
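A usage sketch for `shutdown`, directly following the parameters documented above:

```python
from IPython.parallel import Client

rc = Client()
rc.shutdown(targets=[1, 2], block=True)   # stop engines 1 and 2, wait for replies
rc.shutdown(hub=True)                     # stop all remaining engines and the Hub
```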
1213 1213 #--------------------------------------------------------------------------
1214 1214 # Execution related methods
1215 1215 #--------------------------------------------------------------------------
1216 1216
1217 1217 def _maybe_raise(self, result):
1218 1218 """wrapper for maybe raising an exception if apply failed."""
1219 1219 if isinstance(result, error.RemoteError):
1220 1220 raise result
1221 1221
1222 1222 return result
1223 1223
1224 1224 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1225 1225 ident=None):
1226 1226 """construct and send an apply message via a socket.
1227 1227
1228 1228 This is the principal method with which all engine execution is performed by views.
1229 1229 """
1230 1230
1231 1231 if self._closed:
1232 1232 raise RuntimeError("Client cannot be used after its sockets have been closed")
1233 1233
1234 1234 # defaults:
1235 1235 args = args if args is not None else []
1236 1236 kwargs = kwargs if kwargs is not None else {}
1237 1237 metadata = metadata if metadata is not None else {}
1238 1238
1239 1239 # validate arguments
1240 1240 if not callable(f) and not isinstance(f, Reference):
1241 1241 raise TypeError("f must be callable, not %s"%type(f))
1242 1242 if not isinstance(args, (tuple, list)):
1243 1243 raise TypeError("args must be tuple or list, not %s"%type(args))
1244 1244 if not isinstance(kwargs, dict):
1245 1245 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1246 1246 if not isinstance(metadata, dict):
1247 1247 raise TypeError("metadata must be dict, not %s"%type(metadata))
1248 1248
1249 1249 bufs = serialize.pack_apply_message(f, args, kwargs,
1250 1250 buffer_threshold=self.session.buffer_threshold,
1251 1251 item_threshold=self.session.item_threshold,
1252 1252 )
1253 1253
1254 1254 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1255 1255 metadata=metadata, track=track)
1256 1256
1257 1257 msg_id = msg['header']['msg_id']
1258 1258 self.outstanding.add(msg_id)
1259 1259 if ident:
1260 1260 # possibly routed to a specific engine
1261 1261 if isinstance(ident, list):
1262 1262 ident = ident[-1]
1263 1263 if ident in self._engines.values():
1264 1264 # save for later, in case of engine death
1265 1265 self._outstanding_dict[ident].add(msg_id)
1266 1266 self.history.append(msg_id)
1267 1267 self.metadata[msg_id]['submitted'] = datetime.now()
1268 1268
1269 1269 return msg
1270 1270
1271 1271 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1272 1272 """construct and send an execute request via a socket.
1273 1273
1274 1274 """
1275 1275
1276 1276 if self._closed:
1277 1277 raise RuntimeError("Client cannot be used after its sockets have been closed")
1278 1278
1279 1279 # defaults:
1280 1280 metadata = metadata if metadata is not None else {}
1281 1281
1282 1282 # validate arguments
1283 1283 if not isinstance(code, string_types):
1284 1284 raise TypeError("code must be text, not %s" % type(code))
1285 1285 if not isinstance(metadata, dict):
1286 1286 raise TypeError("metadata must be dict, not %s" % type(metadata))
1287 1287
1288 1288 content = dict(code=code, silent=bool(silent), user_expressions={})
1289 1289
1290 1290
1291 1291 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1292 1292 metadata=metadata)
1293 1293
1294 1294 msg_id = msg['header']['msg_id']
1295 1295 self.outstanding.add(msg_id)
1296 1296 if ident:
1297 1297 # possibly routed to a specific engine
1298 1298 if isinstance(ident, list):
1299 1299 ident = ident[-1]
1300 1300 if ident in self._engines.values():
1301 1301 # save for later, in case of engine death
1302 1302 self._outstanding_dict[ident].add(msg_id)
1303 1303 self.history.append(msg_id)
1304 1304 self.metadata[msg_id]['submitted'] = datetime.now()
1305 1305
1306 1306 return msg
1307 1307
1308 1308 #--------------------------------------------------------------------------
1309 1309 # construct a View object
1310 1310 #--------------------------------------------------------------------------
1311 1311
1312 1312 def load_balanced_view(self, targets=None):
1313 1313 """construct a DirectView object.
1314 1314
1315 1315 If no arguments are specified, create a LoadBalancedView
1316 1316 using all engines.
1317 1317
1318 1318 Parameters
1319 1319 ----------
1320 1320
1321 1321 targets: list,slice,int,etc. [default: use all engines]
1322 1322 The subset of engines across which to load-balance
1323 1323 """
1324 1324 if targets == 'all':
1325 1325 targets = None
1326 1326 if targets is not None:
1327 1327 targets = self._build_targets(targets)[1]
1328 1328 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1329 1329
1330 1330 def direct_view(self, targets='all'):
1331 1331 """construct a DirectView object.
1332 1332
1333 1333 If no targets are specified, create a DirectView using all engines.
1334 1334
1335 1335 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1336 1336 evaluate the target engines at each execution, whereas rc[:] will connect to
1337 1337 all *current* engines, and that list will not change.
1338 1338
1339 1339 That is, 'all' will always use all engines, whereas rc[:] will not use
1340 1340 engines added after the DirectView is constructed.
1341 1341
1342 1342 Parameters
1343 1343 ----------
1344 1344
1345 1345 targets: list,slice,int,etc. [default: use all engines]
1346 1346 The engines to use for the View
1347 1347 """
1348 1348 single = isinstance(targets, int)
1349 1349 # allow 'all' to be lazily evaluated at each execution
1350 1350 if targets != 'all':
1351 1351 targets = self._build_targets(targets)[1]
1352 1352 if single:
1353 1353 targets = targets[0]
1354 1354 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1355 1355
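A sketch contrasting the two view constructors; 'all' stays lazy while a slice is frozen to the current engines, as the `direct_view` docstring explains:

```python
from IPython.parallel import Client

rc = Client()
lview = rc.load_balanced_view()        # tasks go to whichever engine is free
dv_lazy = rc.direct_view('all')        # re-evaluates the engine list on each run
dv_now = rc[:]                         # fixed to the engines registered right now

print(lview.apply_sync(pow, 2, 10))    # 1024, computed on a single engine
print(dv_lazy.apply_sync(pow, 2, 10))  # one 1024 per engine
```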
1356 1356 #--------------------------------------------------------------------------
1357 1357 # Query methods
1358 1358 #--------------------------------------------------------------------------
1359 1359
1360 1360 @spin_first
1361 1361 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1362 1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1363 1363
1364 1364 If the client already has the results, no request to the Hub will be made.
1365 1365
1366 1366 This is a convenient way to construct AsyncResult objects, which are wrappers
1367 1367 that include metadata about execution, and allow for awaiting results that
1368 1368 were not submitted by this Client.
1369 1369
1370 1370 It can also be a convenient way to retrieve the metadata associated with
1371 1371 blocking execution, since it always retrieves the metadata for each msg_id.
1372 1372
1373 1373 Examples
1374 1374 --------
1375 1375 ::
1376 1376
1377 1377 In [10]: ar = client.get_result(-1)  # AsyncResult for the most recent task
1378 1378
1379 1379 Parameters
1380 1380 ----------
1381 1381
1382 1382 indices_or_msg_ids : integer history index, str msg_id, or list of either
1383 1383 The history indices or msg_ids of the tasks to be retrieved
1384 1384
1385 1385 block : bool
1386 1386 Whether to wait for the result to be done
1387 1387 owner : bool [default: True]
1388 1388 Whether this AsyncResult should own the result.
1389 1389 If so, calling `ar.get()` will remove data from the
1390 1390 client's result and metadata cache.
1391 1391 There should only be one owner of any given msg_id.
1392 1392
1393 1393 Returns
1394 1394 -------
1395 1395
1396 1396 AsyncResult
1397 1397 A single AsyncResult object will always be returned.
1398 1398
1399 1399 AsyncHubResult
1400 1400 A subclass of AsyncResult that retrieves results from the Hub
1401 1401
1402 1402 """
1403 1403 block = self.block if block is None else block
1404 1404 if indices_or_msg_ids is None:
1405 1405 indices_or_msg_ids = -1
1406 1406
1407 1407 single_result = False
1408 1408 if not isinstance(indices_or_msg_ids, (list,tuple)):
1409 1409 indices_or_msg_ids = [indices_or_msg_ids]
1410 1410 single_result = True
1411 1411
1412 1412 theids = []
1413 1413 for id in indices_or_msg_ids:
1414 1414 if isinstance(id, int):
1415 1415 id = self.history[id]
1416 1416 if not isinstance(id, string_types):
1417 1417 raise TypeError("indices must be str or int, not %r"%id)
1418 1418 theids.append(id)
1419 1419
1420 1420 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1421 1421 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1422 1422
1423 1423 # given a single msg_id initially, get_result should return the result itself,
1424 1424 # not a length-one list
1425 1425 if single_result:
1426 1426 theids = theids[0]
1427 1427
1428 1428 if remote_ids:
1429 1429 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1430 1430 else:
1431 1431 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1432 1432
1433 1433 if block:
1434 1434 ar.wait()
1435 1435
1436 1436 return ar
1437 1437
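A sketch of `get_result` with both a history index and the `owner` flag described above:

```python
import os
from IPython.parallel import Client

rc = Client()
ar = rc[:].apply_async(os.getpid)
ar.get()

ar_again = rc.get_result(-1, owner=False)   # most recent msg_id, without taking
print(ar_again.get())                       # ownership of the cached result
```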
1438 1438 @spin_first
1439 1439 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1440 1440 """Resubmit one or more tasks.
1441 1441
1442 1442 In-flight tasks may not be resubmitted.
1443 1443
1444 1444 Parameters
1445 1445 ----------
1446 1446
1447 1447 indices_or_msg_ids : integer history index, str msg_id, or list of either
1448 1448 The history indices or msg_ids of the tasks to be resubmitted
1449 1449
1450 1450 block : bool
1451 1451 Whether to wait for the result to be done
1452 1452
1453 1453 Returns
1454 1454 -------
1455 1455
1456 1456 AsyncHubResult
1457 1457 A subclass of AsyncResult that retrieves results from the Hub
1458 1458
1459 1459 """
1460 1460 block = self.block if block is None else block
1461 1461 if indices_or_msg_ids is None:
1462 1462 indices_or_msg_ids = -1
1463 1463
1464 1464 if not isinstance(indices_or_msg_ids, (list,tuple)):
1465 1465 indices_or_msg_ids = [indices_or_msg_ids]
1466 1466
1467 1467 theids = []
1468 1468 for id in indices_or_msg_ids:
1469 1469 if isinstance(id, int):
1470 1470 id = self.history[id]
1471 1471 if not isinstance(id, string_types):
1472 1472 raise TypeError("indices must be str or int, not %r"%id)
1473 1473 theids.append(id)
1474 1474
1475 1475 content = dict(msg_ids = theids)
1476 1476
1477 1477 self.session.send(self._query_socket, 'resubmit_request', content)
1478 1478
1479 1479 zmq.select([self._query_socket], [], [])
1480 1480 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1481 1481 if self.debug:
1482 1482 pprint(msg)
1483 1483 content = msg['content']
1484 1484 if content['status'] != 'ok':
1485 1485 raise self._unwrap_exception(content)
1486 1486 mapping = content['resubmitted']
1487 1487 new_ids = [ mapping[msg_id] for msg_id in theids ]
1488 1488
1489 1489 ar = AsyncHubResult(self, msg_ids=new_ids)
1490 1490
1491 1491 if block:
1492 1492 ar.wait()
1493 1493
1494 1494 return ar
1495 1495
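A sketch of `resubmit`; as the docstring says, the task must have finished first, and a one-element list comes back since the Hub assigns new msg_ids:

```python
from IPython.parallel import Client

rc = Client()
ar = rc.load_balanced_view().apply_async(sum, range(10))
ar.get()                          # ensure the task is no longer in flight

ar2 = rc.resubmit(ar.msg_ids)     # run the same task again via the Hub
print(ar2.get())                  # [45]
```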
1496 1496 @spin_first
1497 1497 def result_status(self, msg_ids, status_only=True):
1498 1498 """Check on the status of the result(s) of the apply request with `msg_ids`.
1499 1499
1500 1500 If status_only is False, then the actual results will be retrieved, else
1501 1501 only the status of the results will be checked.
1502 1502
1503 1503 Parameters
1504 1504 ----------
1505 1505
1506 1506 msg_ids : str, int, or list of either
1507 1507 if int:
1508 1508 passed as an index to self.history for convenience.
1509 1509 status_only : bool (default: True)
1510 1510 if False:
1511 1511 Retrieve the actual results of completed tasks.
1512 1512
1513 1513 Returns
1514 1514 -------
1515 1515
1516 1516 results : dict
1517 1517 There will always be the keys 'pending' and 'completed', which will
1518 1518 be lists of msg_ids that are incomplete or complete. If `status_only`
1519 1519 is False, then completed results will be keyed by their `msg_id`.
1520 1520 """
1521 1521 if not isinstance(msg_ids, (list,tuple)):
1522 1522 msg_ids = [msg_ids]
1523 1523
1524 1524 theids = []
1525 1525 for msg_id in msg_ids:
1526 1526 if isinstance(msg_id, int):
1527 1527 msg_id = self.history[msg_id]
1528 1528 if not isinstance(msg_id, string_types):
1529 1529 raise TypeError("msg_ids must be str, not %r"%msg_id)
1530 1530 theids.append(msg_id)
1531 1531
1532 1532 completed = []
1533 1533 local_results = {}
1534 1534
1535 1535 # comment this block out to temporarily disable local shortcut:
1536 1536 for msg_id in theids:
1537 1537 if msg_id in self.results:
1538 1538 completed.append(msg_id)
1539 1539 local_results[msg_id] = self.results[msg_id]
1540 1540 theids.remove(msg_id)
1541 1541
1542 1542 if theids: # some not locally cached
1543 1543 content = dict(msg_ids=theids, status_only=status_only)
1544 1544 msg = self.session.send(self._query_socket, "result_request", content=content)
1545 1545 zmq.select([self._query_socket], [], [])
1546 1546 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1547 1547 if self.debug:
1548 1548 pprint(msg)
1549 1549 content = msg['content']
1550 1550 if content['status'] != 'ok':
1551 1551 raise self._unwrap_exception(content)
1552 1552 buffers = msg['buffers']
1553 1553 else:
1554 1554 content = dict(completed=[],pending=[])
1555 1555
1556 1556 content['completed'].extend(completed)
1557 1557
1558 1558 if status_only:
1559 1559 return content
1560 1560
1561 1561 failures = []
1562 1562 # load cached results into result:
1563 1563 content.update(local_results)
1564 1564
1565 1565 # update cache with results:
1566 1566 for msg_id in sorted(theids):
1567 1567 if msg_id in content['completed']:
1568 1568 rec = content[msg_id]
1569 1569 parent = extract_dates(rec['header'])
1570 1570 header = extract_dates(rec['result_header'])
1571 1571 rcontent = rec['result_content']
1572 1572 iodict = rec['io']
1573 1573 if isinstance(rcontent, str):
1574 1574 rcontent = self.session.unpack(rcontent)
1575 1575
1576 1576 md = self.metadata[msg_id]
1577 1577 md_msg = dict(
1578 1578 content=rcontent,
1579 1579 parent_header=parent,
1580 1580 header=header,
1581 1581 metadata=rec['result_metadata'],
1582 1582 )
1583 1583 md.update(self._extract_metadata(md_msg))
1584 1584 if rec.get('received'):
1585 1585 md['received'] = parse_date(rec['received'])
1586 1586 md.update(iodict)
1587 1587
1588 1588 if rcontent['status'] == 'ok':
1589 1589 if header['msg_type'] == 'apply_reply':
1590 1590 res,buffers = serialize.unserialize_object(buffers)
1591 1591 elif header['msg_type'] == 'execute_reply':
1592 1592 res = ExecuteReply(msg_id, rcontent, md)
1593 1593 else:
1594 1594 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1595 1595 else:
1596 1596 res = self._unwrap_exception(rcontent)
1597 1597 failures.append(res)
1598 1598
1599 1599 self.results[msg_id] = res
1600 1600 content[msg_id] = res
1601 1601
1602 1602 if len(theids) == 1 and failures:
1603 1603 raise failures[0]
1604 1604
1605 1605 error.collect_exceptions(failures, "result_status")
1606 1606 return content
1607 1607
1608 1608 @spin_first
1609 1609 def queue_status(self, targets='all', verbose=False):
1610 1610 """Fetch the status of engine queues.
1611 1611
1612 1612 Parameters
1613 1613 ----------
1614 1614
1615 1615 targets : int/str/list of ints/strs
1616 1616 the engines whose states are to be queried.
1617 1617 default : all
1618 1618 verbose : bool
1619 1619 Whether to return lengths only, or lists of ids for each element
1620 1620 """
1621 1621 if targets == 'all':
1622 1622 # allow 'all' to be evaluated on the engine
1623 1623 engine_ids = None
1624 1624 else:
1625 1625 engine_ids = self._build_targets(targets)[1]
1626 1626 content = dict(targets=engine_ids, verbose=verbose)
1627 1627 self.session.send(self._query_socket, "queue_request", content=content)
1628 1628 idents,msg = self.session.recv(self._query_socket, 0)
1629 1629 if self.debug:
1630 1630 pprint(msg)
1631 1631 content = msg['content']
1632 1632 status = content.pop('status')
1633 1633 if status != 'ok':
1634 1634 raise self._unwrap_exception(content)
1635 1635 content = rekey(content)
1636 1636 if isinstance(targets, int):
1637 1637 return content[targets]
1638 1638 else:
1639 1639 return content
1640 1640
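A sketch of `queue_status`; the exact keys in the reply (e.g. 'queue', 'tasks', 'completed') depend on the Hub and are shown here only as typical values:

```python
from IPython.parallel import Client

rc = Client()
print(rc.queue_status())              # per-engine dict of queue lengths
print(rc.queue_status(targets=0))     # dict for engine 0 only
print(rc.queue_status(verbose=True))  # msg_id lists instead of counts
```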
1641 1641 def _build_msgids_from_target(self, targets=None):
1642 1642 """Build a list of msg_ids from the list of engine targets"""
1643 1643 if not targets: # needed as _build_targets otherwise uses all engines
1644 1644 return []
1645 1645 target_ids = self._build_targets(targets)[0]
1646 1646 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1647 1647
1648 1648 def _build_msgids_from_jobs(self, jobs=None):
1649 1649 """Build a list of msg_ids from "jobs" """
1650 1650 if not jobs:
1651 1651 return []
1652 1652 msg_ids = []
1653 1653 if isinstance(jobs, string_types + (AsyncResult,)):
1654 1654 jobs = [jobs]
1655 1655 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1656 1656 if bad_ids:
1657 1657 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1658 1658 for j in jobs:
1659 1659 if isinstance(j, AsyncResult):
1660 1660 msg_ids.extend(j.msg_ids)
1661 1661 else:
1662 1662 msg_ids.append(j)
1663 1663 return msg_ids
1664 1664
1665 1665 def purge_local_results(self, jobs=[], targets=[]):
1666 1666 """Clears the client caches of results and their metadata.
1667 1667
1668 1668 Individual results can be purged by msg_id, or the entire
1669 1669 history of specific targets can be purged.
1670 1670
1671 1671 Use `purge_local_results('all')` to scrub everything from the Client's
1672 1672 results and metadata caches.
1673 1673
1674 1674 After this call all `AsyncResults` are invalid and should be discarded.
1675 1675
1676 1676 If you must "reget" the results, you can still do so by using
1677 1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 1678 redownload the results from the hub if they are still available
1679 1679 (i.e. `client.purge_hub_results(...)` has not been called).
1680 1680
1681 1681 Parameters
1682 1682 ----------
1683 1683
1684 1684 jobs : str or list of str or AsyncResult objects
1685 1685 the msg_ids whose results should be purged.
1686 1686 targets : int/list of ints
1687 1687 The engines, by integer ID, whose entire result histories are to be purged.
1688 1688
1689 1689 Raises
1690 1690 ------
1691 1691
1692 1692 RuntimeError : if any of the tasks to be purged are still outstanding.
1693 1693
1694 1694 """
1695 1695 if not targets and not jobs:
1696 1696 raise ValueError("Must specify at least one of `targets` and `jobs`")
1697 1697
1698 1698 if jobs == 'all':
1699 1699 if self.outstanding:
1700 1700 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1701 1701 self.results.clear()
1702 1702 self.metadata.clear()
1703 1703 else:
1704 1704 msg_ids = set()
1705 1705 msg_ids.update(self._build_msgids_from_target(targets))
1706 1706 msg_ids.update(self._build_msgids_from_jobs(jobs))
1707 1707 still_outstanding = self.outstanding.intersection(msg_ids)
1708 1708 if still_outstanding:
1709 1709 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1710 1710 for mid in msg_ids:
1711 1711 self.results.pop(mid, None)
1712 1712 self.metadata.pop(mid, None)
1713 1713
1714 1714
1715 1715 @spin_first
1716 1716 def purge_hub_results(self, jobs=[], targets=[]):
1717 1717 """Tell the Hub to forget results.
1718 1718
1719 1719 Individual results can be purged by msg_id, or the entire
1720 1720 history of specific targets can be purged.
1721 1721
1722 1722 Use `purge_hub_results('all')` to scrub everything from the Hub's db.
1723 1723
1724 1724 Parameters
1725 1725 ----------
1726 1726
1727 1727 jobs : str or list of str or AsyncResult objects
1728 1728 the msg_ids whose results should be forgotten.
1729 1729 targets : int/str/list of ints/strs
1730 1730 The targets, by int_id, whose entire history is to be purged.
1731 1731
1732 1732 default : None
1733 1733 """
1734 1734 if not targets and not jobs:
1735 1735 raise ValueError("Must specify at least one of `targets` and `jobs`")
1736 1736 if targets:
1737 1737 targets = self._build_targets(targets)[1]
1738 1738
1739 1739 # construct msg_ids from jobs
1740 1740 if jobs == 'all':
1741 1741 msg_ids = jobs
1742 1742 else:
1743 1743 msg_ids = self._build_msgids_from_jobs(jobs)
1744 1744
1745 1745 content = dict(engine_ids=targets, msg_ids=msg_ids)
1746 1746 self.session.send(self._query_socket, "purge_request", content=content)
1747 1747 idents, msg = self.session.recv(self._query_socket, 0)
1748 1748 if self.debug:
1749 1749 pprint(msg)
1750 1750 content = msg['content']
1751 1751 if content['status'] != 'ok':
1752 1752 raise self._unwrap_exception(content)
1753 1753
1754 1754 def purge_results(self, jobs=[], targets=[]):
1755 1755 """Clears the cached results from both the hub and the local client
1756 1756
1757 1757 Individual results can be purged by msg_id, or the entire
1758 1758 history of specific targets can be purged.
1759 1759
1760 1760 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1761 1761 the Client's db.
1762 1762
1763 1763 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1764 1764 the same arguments.
1765 1765
1766 1766 Parameters
1767 1767 ----------
1768 1768
1769 1769 jobs : str or list of str or AsyncResult objects
1770 1770 the msg_ids whose results should be forgotten.
1771 1771 targets : int/str/list of ints/strs
1772 1772 The targets, by int_id, whose entire history is to be purged.
1773 1773
1774 1774 default : None
1775 1775 """
1776 1776 self.purge_local_results(jobs=jobs, targets=targets)
1777 1777 self.purge_hub_results(jobs=jobs, targets=targets)
1778 1778
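A sketch tying the three purge methods together, following the docstrings above:

```python
from IPython.parallel import Client

rc = Client()
rc[:].apply_sync(pow, 2, 4)           # run something so there is history
rc.purge_local_results(jobs='all')    # forget results cached in this client
rc.purge_hub_results(targets=[0])     # drop engine 0's history from the Hub db
rc.purge_results('all')               # or scrub both caches in one call
```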
1779 1779 def purge_everything(self):
1780 1780 """Clears all content from previous Tasks from both the hub and the local client
1781 1781
1782 1782 In addition to calling `purge_results("all")` it also deletes the history and
1783 1783 other bookkeeping lists.
1784 1784 """
1785 1785 self.purge_results("all")
1786 1786 self.history = []
1787 1787 self.session.digest_history.clear()
1788 1788
1789 1789 @spin_first
1790 1790 def hub_history(self):
1791 1791 """Get the Hub's history
1792 1792
1793 1793 Just like the Client, the Hub has a history, which is a list of msg_ids.
1794 1794 This will contain the history of all clients, and, depending on configuration,
1795 1795 may contain history across multiple cluster sessions.
1796 1796
1797 1797 Any msg_id returned here is a valid argument to `get_result`.
1798 1798
1799 1799 Returns
1800 1800 -------
1801 1801
1802 1802 msg_ids : list of strs
1803 1803 list of all msg_ids, ordered by task submission time.
1804 1804 """
1805 1805
1806 1806 self.session.send(self._query_socket, "history_request", content={})
1807 1807 idents, msg = self.session.recv(self._query_socket, 0)
1808 1808
1809 1809 if self.debug:
1810 1810 pprint(msg)
1811 1811 content = msg['content']
1812 1812 if content['status'] != 'ok':
1813 1813 raise self._unwrap_exception(content)
1814 1814 else:
1815 1815 return content['history']
1816 1816
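A sketch of `hub_history` feeding `get_result`, which the docstring guarantees accepts any msg_id returned here:

```python
from IPython.parallel import Client

rc = Client()
msg_ids = rc.hub_history()            # every task the Hub still remembers
if msg_ids:
    ar = rc.get_result(msg_ids[-1])   # latest submission, from any client
    print(ar.get())
```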
1817 1817 @spin_first
1818 1818 def db_query(self, query, keys=None):
1819 1819 """Query the Hub's TaskRecord database
1820 1820
1821 1821 This will return a list of task record dicts that match `query`
1822 1822
1823 1823 Parameters
1824 1824 ----------
1825 1825
1826 1826 query : mongodb query dict
1827 1827 The search dict. See mongodb query docs for details.
1828 1828 keys : list of strs [optional]
1829 1829 The subset of keys to be returned. The default is to fetch everything but buffers.
1830 1830 'msg_id' will *always* be included.
1831 1831 """
1832 1832 if isinstance(keys, string_types):
1833 1833 keys = [keys]
1834 1834 content = dict(query=query, keys=keys)
1835 1835 self.session.send(self._query_socket, "db_request", content=content)
1836 1836 idents, msg = self.session.recv(self._query_socket, 0)
1837 1837 if self.debug:
1838 1838 pprint(msg)
1839 1839 content = msg['content']
1840 1840 if content['status'] != 'ok':
1841 1841 raise self._unwrap_exception(content)
1842 1842
1843 1843 records = content['records']
1844 1844
1845 1845 buffer_lens = content['buffer_lens']
1846 1846 result_buffer_lens = content['result_buffer_lens']
1847 1847 buffers = msg['buffers']
1848 1848 has_bufs = buffer_lens is not None
1849 1849 has_rbufs = result_buffer_lens is not None
1850 1850 for i,rec in enumerate(records):
1851 1851 # unpack datetime objects
1852 1852 for hkey in ('header', 'result_header'):
1853 1853 if hkey in rec:
1854 1854 rec[hkey] = extract_dates(rec[hkey])
1855 1855 for dtkey in ('submitted', 'started', 'completed', 'received'):
1856 1856 if dtkey in rec:
1857 1857 rec[dtkey] = parse_date(rec[dtkey])
1858 1858 # relink buffers
1859 1859 if has_bufs:
1860 1860 blen = buffer_lens[i]
1861 1861 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1862 1862 if has_rbufs:
1863 1863 blen = result_buffer_lens[i]
1864 1864 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1865 1865
1866 1866 return records
1867 1867
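A sketch of `db_query` using a MongoDB-style operator; the `$gt` filter is an illustrative query, and the key names (`submitted`, `completed`, `msg_id`) follow the TaskRecord fields used in this file:

```python
from datetime import datetime, timedelta
from IPython.parallel import Client

rc = Client()
yesterday = datetime.now() - timedelta(days=1)

# all tasks submitted in the last day, fetching only two columns
recs = rc.db_query({'submitted': {'$gt': yesterday}},
                   keys=['msg_id', 'completed'])
for rec in recs:
    print(rec['msg_id'], rec['completed'])
```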
1868 1868 __all__ = [ 'Client' ]
@@ -1,298 +1,297 b''
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 It handles registration, etc., and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from __future__ import print_function
10 10
11 11 import sys
12 12 import time
13 13 from getpass import getpass
14 14
15 15 import zmq
16 16 from zmq.eventloop import ioloop, zmqstream
17 from zmq.ssh import tunnel
17 18
18 from IPython.external.ssh import tunnel
19 # internal
20 19 from IPython.utils.localinterfaces import localhost
21 20 from IPython.utils.traitlets import (
22 21 Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
23 22 )
24 23 from IPython.utils.py3compat import cast_bytes
25 24
26 25 from IPython.parallel.controller.heartmonitor import Heart
27 26 from IPython.parallel.factory import RegistrationFactory
28 27 from IPython.parallel.util import disambiguate_url
29 28
30 29 from IPython.kernel.zmq.session import Message
31 30 from IPython.kernel.zmq.ipkernel import Kernel
32 31 from IPython.kernel.zmq.kernelapp import IPKernelApp
33 32
34 33 class EngineFactory(RegistrationFactory):
35 34 """IPython engine"""
36 35
37 36 # configurables:
38 37 out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
39 38 help="""The OutStream for handling stdout/err.
40 39 Typically 'IPython.kernel.zmq.iostream.OutStream'""")
41 40 display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
42 41 help="""The class for handling displayhook.
43 42 Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
44 43 location=Unicode(config=True,
45 44 help="""The location (an IP address) of the controller. This is
46 45 used for disambiguating URLs, to determine whether
47 46 loopback should be used to connect or the public address.""")
48 47 timeout=Float(5.0, config=True,
49 48 help="""The time (in seconds) to wait for the Controller to respond
50 49 to registration requests before giving up.""")
51 50 max_heartbeat_misses=Integer(50, config=True,
52 51 help="""The maximum number of times a check for the heartbeat ping of a
53 52 controller can be missed before shutting down the engine.
54 53
55 54 If set to 0, the check is disabled.""")
56 55 sshserver=Unicode(config=True,
57 56 help="""The SSH server to use for tunneling connections to the Controller.""")
58 57 sshkey=Unicode(config=True,
59 58 help="""The SSH private key file to use when tunneling connections to the Controller.""")
60 59 paramiko=Bool(sys.platform == 'win32', config=True,
61 60 help="""Whether to use paramiko instead of openssh for tunnels.""")
62 61
63 62
64 63 # not configurable:
65 64 connection_info = Dict()
66 65 user_ns = Dict()
67 66 id = Integer(allow_none=True)
68 67 registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
69 68 kernel = Instance(Kernel)
70 69 hb_check_period=Integer()
71 70
72 71 # States for the heartbeat monitoring
73 72 # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
74 73 # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
75 74 _hb_last_pinged = 0.0
76 75 _hb_last_monitored = 0.0
77 76 _hb_missed_beats = 0
78 77 # The zmq Stream which receives the pings from the Heart
79 78 _hb_listener = None
80 79
81 80 bident = CBytes()
82 81 ident = Unicode()
83 82 def _ident_changed(self, name, old, new):
84 83 self.bident = cast_bytes(new)
85 84 using_ssh=Bool(False)
86 85
87 86
88 87 def __init__(self, **kwargs):
89 88 super(EngineFactory, self).__init__(**kwargs)
90 89 self.ident = self.session.session
91 90
92 91 def init_connector(self):
93 92 """construct connection function, which handles tunnels."""
94 93 self.using_ssh = bool(self.sshkey or self.sshserver)
95 94
96 95 if self.sshkey and not self.sshserver:
97 96 # We are using ssh directly to the controller, tunneling localhost to localhost
98 97 self.sshserver = self.url.split('://')[1].split(':')[0]
99 98
100 99 if self.using_ssh:
101 100 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
102 101 password=False
103 102 else:
104 103 password = getpass("SSH Password for %s: "%self.sshserver)
105 104 else:
106 105 password = False
107 106
108 107 def connect(s, url):
109 108 url = disambiguate_url(url, self.location)
110 109 if self.using_ssh:
111 110 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
112 111 return tunnel.tunnel_connection(s, url, self.sshserver,
113 112 keyfile=self.sshkey, paramiko=self.paramiko,
114 113 password=password,
115 114 )
116 115 else:
117 116 return s.connect(url)
118 117
119 118 def maybe_tunnel(url):
120 119 """like connect, but don't complete the connection (for use by heartbeat)"""
121 120 url = disambiguate_url(url, self.location)
122 121 if self.using_ssh:
123 122 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
124 123 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
125 124 keyfile=self.sshkey, paramiko=self.paramiko,
126 125 password=password,
127 126 )
128 127 return str(url)
129 128 return connect, maybe_tunnel
130 129
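A configuration sketch for the tunneling path constructed above, assuming the usual `ipengine_config.py` profile file; host names and key paths are placeholders:

```python
# ipengine_config.py -- all values here are illustrative placeholders
c = get_config()

c.EngineFactory.location = '10.0.0.1'                   # controller's IP
c.EngineFactory.sshserver = 'user@gateway.example.com'  # SSH host to tunnel through
c.EngineFactory.sshkey = '~/.ssh/id_rsa'                # private key for the tunnel
c.EngineFactory.paramiko = False                        # use openssh, not paramiko
```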
131 130 def register(self):
132 131 """send the registration_request"""
133 132
134 133 self.log.info("Registering with controller at %s"%self.url)
135 134 ctx = self.context
136 135 connect,maybe_tunnel = self.init_connector()
137 136 reg = ctx.socket(zmq.DEALER)
138 137 reg.setsockopt(zmq.IDENTITY, self.bident)
139 138 connect(reg, self.url)
140 139 self.registrar = zmqstream.ZMQStream(reg, self.loop)
141 140
142 141
143 142 content = dict(uuid=self.ident)
144 143 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
145 144 # print (self.session.key)
146 145 self.session.send(self.registrar, "registration_request", content=content)
147 146
148 147 def _report_ping(self, msg):
149 148 """Callback for when the heartmonitor.Heart receives a ping"""
150 149 #self.log.debug("Received a ping: %s", msg)
151 150 self._hb_last_pinged = time.time()
152 151
153 152 def complete_registration(self, msg, connect, maybe_tunnel):
154 153 # print msg
155 154 self._abort_dc.stop()
156 155 ctx = self.context
157 156 loop = self.loop
158 157 identity = self.bident
159 158 idents,msg = self.session.feed_identities(msg)
160 159 msg = self.session.unserialize(msg)
161 160 content = msg['content']
162 161 info = self.connection_info
163 162
164 163 def url(key):
165 164 """get zmq url for given channel"""
166 165 return str(info["interface"] + ":%i" % info[key])
167 166
168 167 if content['status'] == 'ok':
169 168 self.id = int(content['id'])
170 169
171 170 # launch heartbeat
172 171 # possibly forward hb ports with tunnels
173 172 hb_ping = maybe_tunnel(url('hb_ping'))
174 173 hb_pong = maybe_tunnel(url('hb_pong'))
175 174
176 175 hb_monitor = None
177 176 if self.max_heartbeat_misses > 0:
178 177 # Add a monitor socket which will record the last time a ping was seen
179 178 mon = self.context.socket(zmq.SUB)
180 179 mport = mon.bind_to_random_port('tcp://%s' % localhost())
181 180 mon.setsockopt(zmq.SUBSCRIBE, b"")
182 181 self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
183 182 self._hb_listener.on_recv(self._report_ping)
184 183
185 184
186 185 hb_monitor = "tcp://%s:%i" % (localhost(), mport)
187 186
188 187 heart = Heart(hb_ping, hb_pong, hb_monitor, heart_id=identity)
189 188 heart.start()
190 189
191 190 # create Shell Connections (MUX, Task, etc.):
192 191 shell_addrs = url('mux'), url('task')
193 192
194 193 # Use only one shell stream for mux and tasks
195 194 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
196 195 stream.setsockopt(zmq.IDENTITY, identity)
197 196 shell_streams = [stream]
198 197 for addr in shell_addrs:
199 198 connect(stream, addr)
200 199
201 200 # control stream:
202 201 control_addr = url('control')
203 202 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
204 203 control_stream.setsockopt(zmq.IDENTITY, identity)
205 204 connect(control_stream, control_addr)
206 205
207 206 # create iopub stream:
208 207 iopub_addr = url('iopub')
209 208 iopub_socket = ctx.socket(zmq.PUB)
210 209 iopub_socket.setsockopt(zmq.IDENTITY, identity)
211 210 connect(iopub_socket, iopub_addr)
212 211
213 212 # disable history:
214 213 self.config.HistoryManager.hist_file = ':memory:'
215 214
216 215 # Redirect input streams and set a display hook.
217 216 if self.out_stream_factory:
218 217 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
219 218 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
220 219 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
221 220 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
222 221 if self.display_hook_factory:
223 222 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
224 223 sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)
225 224
226 225 self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
227 226 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
228 227 loop=loop, user_ns=self.user_ns, log=self.log)
229 228
230 229 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
231 230
232 231
233 232 # periodically check the heartbeat pings of the controller
234 233 # Should be started here and not in "start()" so that the right period can be taken
235 234 # from the hubs HeartBeatMonitor.period
236 235 if self.max_heartbeat_misses > 0:
237 236 # Use a slightly bigger check period than the hub signal period to avoid unnecessary warnings
238 237 self.hb_check_period = int(content['hb_period'])+10
239 238 self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms." , self.hb_check_period)
240 239 self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
241 240 self._hb_reporter.start()
242 241 else:
243 242 self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")
244 243
245 244
246 245 # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
247 246 app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
248 247 app.init_profile_dir()
249 248 app.init_code()
250 249
251 250 self.kernel.start()
252 251 else:
253 252 self.log.fatal("Registration Failed: %s"%msg)
254 253 raise Exception("Registration Failed: %s"%msg)
255 254
256 255 self.log.info("Completed registration with id %i"%self.id)
257 256
258 257
259 258 def abort(self):
260 259 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
261 260 if self.url.startswith('127.'):
262 261 self.log.fatal("""
263 262 If the controller and engines are not on the same machine,
264 263 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
265 264 c.HubFactory.ip='*' # for all interfaces, internal and external
266 265 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
267 266 or tunnel connections via ssh.
268 267 """)
269 268 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
270 269 time.sleep(1)
271 270 sys.exit(255)
272 271
273 272 def _hb_monitor(self):
274 273 """Callback to monitor the heartbeat from the controller"""
275 274 self._hb_listener.flush()
276 275 if self._hb_last_monitored > self._hb_last_pinged:
277 276 self._hb_missed_beats += 1
278 277 self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
279 278 else:
280 279 #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
281 280 self._hb_missed_beats = 0
282 281
283 282 if self._hb_missed_beats >= self.max_heartbeat_misses:
284 283 self.log.fatal("Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
285 284 self.max_heartbeat_misses, self.hb_check_period)
286 285 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
287 286 self.loop.stop()
288 287
289 288 self._hb_last_monitored = time.time()
290 289
291 290
292 291 def start(self):
293 292 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
294 293 dc.start()
295 294 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
296 295 self._abort_dc.start()
297 296
298 297
1 NO CONTENT: file was removed
1 NO CONTENT: file was removed
1 NO CONTENT: file was removed