Merge pull request #6042 from minrk/import-less-than-everything...
Thomas Kluyver
r17055:c03e562b merge
@@ -1,171 +1,173 b''
1 1 """Manage IPython.parallel clusters in the notebook.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 """
7 7
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2008-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from tornado import web
20 20 from zmq.eventloop import ioloop
21 21
22 22 from IPython.config.configurable import LoggingConfigurable
23 23 from IPython.utils.traitlets import Dict, Instance, CFloat
24 from IPython.parallel.apps.ipclusterapp import IPClusterStart
25 24 from IPython.core.profileapp import list_profiles_in
26 25 from IPython.core.profiledir import ProfileDir
27 26 from IPython.utils import py3compat
28 27 from IPython.utils.path import get_ipython_dir
29 28
30 29
31 30 #-----------------------------------------------------------------------------
32 31 # Classes
33 32 #-----------------------------------------------------------------------------
34 33
35 34
36 class DummyIPClusterStart(IPClusterStart):
37 """Dummy subclass to skip init steps that conflict with global app.
38
39 Instantiating and initializing this class should result in fully configured
40 launchers, but no other side effects or state.
41 """
42
43 def init_signal(self):
44 pass
45 def reinit_logging(self):
46 pass
47 35
48 36
49 37 class ClusterManager(LoggingConfigurable):
50 38
51 39 profiles = Dict()
52 40
53 41 delay = CFloat(1., config=True,
54 42 help="delay (in s) between starting the controller and the engines")
55 43
56 44 loop = Instance('zmq.eventloop.ioloop.IOLoop')
57 45 def _loop_default(self):
58 46 from zmq.eventloop.ioloop import IOLoop
59 47 return IOLoop.instance()
60 48
61 49 def build_launchers(self, profile_dir):
50 from IPython.parallel.apps.ipclusterapp import IPClusterStart
51
52 class DummyIPClusterStart(IPClusterStart):
53 """Dummy subclass to skip init steps that conflict with global app.
54
55 Instantiating and initializing this class should result in fully configured
56 launchers, but no other side effects or state.
57 """
58
59 def init_signal(self):
60 pass
61 def reinit_logging(self):
62 pass
63
62 64 starter = DummyIPClusterStart(log=self.log)
63 65 starter.initialize(['--profile-dir', profile_dir])
64 66 cl = starter.controller_launcher
65 67 esl = starter.engine_launcher
66 68 n = starter.n
67 69 return cl, esl, n
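
Review note: the point of this hunk is import cost. `IPClusterStart` drags in all of `IPython.parallel`, so the import and the `DummyIPClusterStart` subclass move inside `build_launchers`, where the price is only paid when a cluster is actually started. A rough way to observe the effect (module path assumed from this branch's layout; timings vary):

    import sys, time

    t0 = time.time()
    from IPython.html.services.clusters.clustermanager import ClusterManager
    print("clustermanager import: %.2fs" % (time.time() - t0))
    # with this change, the heavy dependency is no longer pulled in eagerly:
    print('IPython.parallel' in sys.modules)  # expect False until start_cluster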
68 70
69 71 def get_profile_dir(self, name, path):
70 72 p = ProfileDir.find_profile_dir_by_name(path,name=name)
71 73 return p.location
72 74
73 75 def update_profiles(self):
74 76 """List all profiles in the ipython_dir and cwd.
75 77 """
76 78 for path in [get_ipython_dir(), py3compat.getcwd()]:
77 79 for profile in list_profiles_in(path):
78 80 pd = self.get_profile_dir(profile, path)
79 81 if profile not in self.profiles:
80 82 self.log.debug("Adding cluster profile '%s'" % profile)
81 83 self.profiles[profile] = {
82 84 'profile': profile,
83 85 'profile_dir': pd,
84 86 'status': 'stopped'
85 87 }
86 88
87 89 def list_profiles(self):
88 90 self.update_profiles()
89 91 # sorted list, but ensure that 'default' always comes first
90 92 default_first = lambda name: name if name != 'default' else ''
91 93 result = [self.profile_info(p) for p in sorted(self.profiles, key=default_first)]
92 94 return result
93 95
94 96 def check_profile(self, profile):
95 97 if profile not in self.profiles:
96 98 raise web.HTTPError(404, u'profile not found')
97 99
98 100 def profile_info(self, profile):
99 101 self.check_profile(profile)
100 102 result = {}
101 103 data = self.profiles.get(profile)
102 104 result['profile'] = profile
103 105 result['profile_dir'] = data['profile_dir']
104 106 result['status'] = data['status']
105 107 if 'n' in data:
106 108 result['n'] = data['n']
107 109 return result
108 110
109 111 def start_cluster(self, profile, n=None):
110 112 """Start a cluster for a given profile."""
111 113 self.check_profile(profile)
112 114 data = self.profiles[profile]
113 115 if data['status'] == 'running':
114 116 raise web.HTTPError(409, u'cluster already running')
115 117 cl, esl, default_n = self.build_launchers(data['profile_dir'])
116 118 n = n if n is not None else default_n
117 119 def clean_data():
118 120 data.pop('controller_launcher',None)
119 121 data.pop('engine_set_launcher',None)
120 122 data.pop('n',None)
121 123 data['status'] = 'stopped'
122 124 def engines_stopped(r):
123 125 self.log.debug('Engines stopped')
124 126 if cl.running:
125 127 cl.stop()
126 128 clean_data()
127 129 esl.on_stop(engines_stopped)
128 130 def controller_stopped(r):
129 131 self.log.debug('Controller stopped')
130 132 if esl.running:
131 133 esl.stop()
132 134 clean_data()
133 135 cl.on_stop(controller_stopped)
134 136
135 137 dc = ioloop.DelayedCallback(lambda: cl.start(), 0, self.loop)
136 138 dc.start()
137 139 dc = ioloop.DelayedCallback(lambda: esl.start(n), 1000*self.delay, self.loop)
138 140 dc.start()
139 141
140 142 self.log.debug('Cluster started')
141 143 data['controller_launcher'] = cl
142 144 data['engine_set_launcher'] = esl
143 145 data['n'] = n
144 146 data['status'] = 'running'
145 147 return self.profile_info(profile)
146 148
147 149 def stop_cluster(self, profile):
148 150 """Stop a cluster for a given profile."""
149 151 self.check_profile(profile)
150 152 data = self.profiles[profile]
151 153 if data['status'] == 'stopped':
152 154 raise web.HTTPError(409, u'cluster not running')
153 155 data = self.profiles[profile]
154 156 cl = data['controller_launcher']
155 157 esl = data['engine_set_launcher']
156 158 if cl.running:
157 159 cl.stop()
158 160 if esl.running:
159 161 esl.stop()
160 162 # Return a temp info dict, the real one is updated in the on_stop
161 163 # logic above.
162 164 result = {
163 165 'profile': data['profile'],
164 166 'profile_dir': data['profile_dir'],
165 167 'status': 'stopped'
166 168 }
167 169 return result
168 170
169 171 def stop_all_clusters(self):
170 172 for p in self.profiles.keys():
171 173 self.stop_cluster(p)
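
Review note: end to end, the manager is used roughly as below (a hedged sketch; the DelayedCallbacks above mean `cm.loop` must be running for the controller and engines to actually come up, and the module path is assumed):

    from IPython.html.services.clusters.clustermanager import ClusterManager

    cm = ClusterManager()
    print(cm.list_profiles())         # [{'profile': 'default', 'status': 'stopped', ...}]
    cm.start_cluster('default', n=2)  # schedules controller, then engines after `delay`
    # ... with cm.loop running ...
    cm.stop_cluster('default')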
@@ -1,576 +1,576 b''
1 1 """Utilities for connecting to kernels
2 2
3 3 The :class:`ConnectionFileMixin` class in this module encapsulates the logic
4 4 related to writing and reading connections files.
5 5 """
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import absolute_import
14 14
15 15 import glob
16 16 import json
17 17 import os
18 18 import socket
19 19 import sys
20 20 from getpass import getpass
21 21 from subprocess import Popen, PIPE
22 22 import tempfile
23 23
24 24 import zmq
25 from zmq.ssh import tunnel
26 25
27 26 # IPython imports
28 27 from IPython.config import LoggingConfigurable
29 28 from IPython.core.profiledir import ProfileDir
30 29 from IPython.utils.localinterfaces import localhost
31 30 from IPython.utils.path import filefind, get_ipython_dir
32 31 from IPython.utils.py3compat import (str_to_bytes, bytes_to_str, cast_bytes_py2,
33 32 string_types)
34 33 from IPython.utils.traitlets import (
35 34 Bool, Integer, Unicode, CaselessStrEnum, Instance,
36 35 )
37 36
38 37
39 38 #-----------------------------------------------------------------------------
40 39 # Working with Connection Files
41 40 #-----------------------------------------------------------------------------
42 41
43 42 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
44 43 control_port=0, ip='', key=b'', transport='tcp',
45 44 signature_scheme='hmac-sha256',
46 45 ):
47 46 """Generates a JSON config file, including the selection of random ports.
48 47
49 48 Parameters
50 49 ----------
51 50
52 51 fname : unicode
53 52 The path to the file to write
54 53
55 54 shell_port : int, optional
56 55 The port to use for ROUTER (shell) channel.
57 56
58 57 iopub_port : int, optional
59 58 The port to use for the SUB channel.
60 59
61 60 stdin_port : int, optional
62 61 The port to use for the ROUTER (raw input) channel.
63 62
64 63 control_port : int, optional
65 64 The port to use for the ROUTER (control) channel.
66 65
67 66 hb_port : int, optional
68 67 The port to use for the heartbeat REP channel.
69 68
70 69 ip : str, optional
71 70 The ip address the kernel will bind to.
72 71
73 72 key : str, optional
74 73 The Session key used for message authentication.
75 74
76 75 signature_scheme : str, optional
77 76 The scheme used for message authentication.
78 77 This has the form 'digest-hash', where 'digest'
79 78 is the scheme used for digests, and 'hash' is the name of the hash function
80 79 used by the digest scheme.
81 80 Currently, 'hmac' is the only supported digest scheme,
82 81 and 'sha256' is the default hash function.
83 82
84 83 """
85 84 if not ip:
86 85 ip = localhost()
87 86 # default to a temporary connection file
88 87 if not fname:
89 88 fd, fname = tempfile.mkstemp('.json')
90 89 os.close(fd)
91 90
92 91 # Find open ports as necessary.
93 92
94 93 ports = []
95 94 ports_needed = int(shell_port <= 0) + \
96 95 int(iopub_port <= 0) + \
97 96 int(stdin_port <= 0) + \
98 97 int(control_port <= 0) + \
99 98 int(hb_port <= 0)
100 99 if transport == 'tcp':
101 100 for i in range(ports_needed):
102 101 sock = socket.socket()
103 102 # struct.pack('ii', (0,0)) is 8 null bytes
104 103 sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8)
105 104 sock.bind(('', 0))
106 105 ports.append(sock)
107 106 for i, sock in enumerate(ports):
108 107 port = sock.getsockname()[1]
109 108 sock.close()
110 109 ports[i] = port
111 110 else:
112 111 N = 1
113 112 for i in range(ports_needed):
114 113 while os.path.exists("%s-%s" % (ip, str(N))):
115 114 N += 1
116 115 ports.append(N)
117 116 N += 1
118 117 if shell_port <= 0:
119 118 shell_port = ports.pop(0)
120 119 if iopub_port <= 0:
121 120 iopub_port = ports.pop(0)
122 121 if stdin_port <= 0:
123 122 stdin_port = ports.pop(0)
124 123 if control_port <= 0:
125 124 control_port = ports.pop(0)
126 125 if hb_port <= 0:
127 126 hb_port = ports.pop(0)
128 127
129 128 cfg = dict( shell_port=shell_port,
130 129 iopub_port=iopub_port,
131 130 stdin_port=stdin_port,
132 131 control_port=control_port,
133 132 hb_port=hb_port,
134 133 )
135 134 cfg['ip'] = ip
136 135 cfg['key'] = bytes_to_str(key)
137 136 cfg['transport'] = transport
138 137 cfg['signature_scheme'] = signature_scheme
139 138
140 139 with open(fname, 'w') as f:
141 140 f.write(json.dumps(cfg, indent=2))
142 141
143 142 return fname, cfg
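
Review note: a quick sanity check of the writer (import path assumed to be this module, IPython.kernel.connect):

    from IPython.kernel.connect import write_connection_file

    fname, cfg = write_connection_file(key=b'supersecret')
    print(fname)        # a tempfile ending in .json, since fname was omitted
    print(sorted(cfg))  # ['control_port', 'hb_port', 'iopub_port', 'ip', 'key',
                        #  'shell_port', 'signature_scheme', 'stdin_port', 'transport']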
144 143
145 144
146 145 def get_connection_file(app=None):
147 146 """Return the path to the connection file of an app
148 147
149 148 Parameters
150 149 ----------
151 150 app : IPKernelApp instance [optional]
152 151 If unspecified, the currently running app will be used
153 152 """
154 153 if app is None:
155 154 from IPython.kernel.zmq.kernelapp import IPKernelApp
156 155 if not IPKernelApp.initialized():
157 156 raise RuntimeError("app not specified, and not in a running Kernel")
158 157
159 158 app = IPKernelApp.instance()
160 159 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
161 160
162 161
163 162 def find_connection_file(filename, profile=None):
164 163 """find a connection file, and return its absolute path.
165 164
166 165 The current working directory and the profile's security
167 166 directory will be searched for the file if it is not given by
168 167 absolute path.
169 168
170 169 If profile is unspecified, then the current running application's
171 170 profile will be used, or 'default', if not run from IPython.
172 171
173 172 If the argument does not match an existing file, it will be interpreted as a
174 173 fileglob, and the matching file in the profile's security dir with
175 174 the latest access time will be used.
176 175
177 176 Parameters
178 177 ----------
179 178 filename : str
180 179 The connection file or fileglob to search for.
181 180 profile : str [optional]
182 181 The name of the profile to use when searching for the connection file,
183 182 if different from the current IPython session or 'default'.
184 183
185 184 Returns
186 185 -------
187 186 str : The absolute path of the connection file.
188 187 """
189 188 from IPython.core.application import BaseIPythonApplication as IPApp
190 189 try:
191 190 # quick check for absolute path, before going through logic
192 191 return filefind(filename)
193 192 except IOError:
194 193 pass
195 194
196 195 if profile is None:
197 196 # profile unspecified, check if running from an IPython app
198 197 if IPApp.initialized():
199 198 app = IPApp.instance()
200 199 profile_dir = app.profile_dir
201 200 else:
202 201 # not running in IPython, use default profile
203 202 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
204 203 else:
205 204 # find profiledir by profile name:
206 205 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
207 206 security_dir = profile_dir.security_dir
208 207
209 208 try:
210 209 # first, try explicit name
211 210 return filefind(filename, ['.', security_dir])
212 211 except IOError:
213 212 pass
214 213
215 214 # not found by full name
216 215
217 216 if '*' in filename:
218 217 # given as a glob already
219 218 pat = filename
220 219 else:
221 220 # accept any substring match
222 221 pat = '*%s*' % filename
223 222 matches = glob.glob( os.path.join(security_dir, pat) )
224 223 if not matches:
225 224 raise IOError("Could not find %r in %r" % (filename, security_dir))
226 225 elif len(matches) == 1:
227 226 return matches[0]
228 227 else:
229 228 # get most recent match, by access time:
230 229 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
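
Review note: the three lookup spellings the docstring promises, side by side (file names illustrative):

    from IPython.kernel.connect import find_connection_file

    find_connection_file('/abs/path/kernel-1234.json')  # absolute path, returned as-is
    find_connection_file('kernel-1234.json')  # exact name in cwd or the security dir
    find_connection_file('1234')              # substring -> glob '*1234*';
                                              # newest access time wins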
231 230
232 231
233 232 def get_connection_info(connection_file=None, unpack=False, profile=None):
234 233 """Return the connection information for the current Kernel.
235 234
236 235 Parameters
237 236 ----------
238 237 connection_file : str [optional]
239 238 The connection file to be used. Can be given by absolute path, or
240 239 IPython will search in the security directory of a given profile.
241 240
242 241
243 242 If unspecified, the connection file for the currently running
244 243 IPython Kernel will be used, which is only allowed from inside a kernel.
245 244 unpack : bool [default: False]
246 245 if True, return the unpacked dict, otherwise just the string contents
247 246 of the file.
248 247 profile : str [optional]
249 248 The name of the profile to use when searching for the connection file,
250 249 if different from the current IPython session or 'default'.
251 250
252 251
253 252 Returns
254 253 -------
255 254 The connection dictionary of the current kernel, as string or dict,
256 255 depending on `unpack`.
257 256 """
258 257 if connection_file is None:
259 258 # get connection file from current kernel
260 259 cf = get_connection_file()
261 260 else:
262 261 # connection file specified, allow shortnames:
263 262 cf = find_connection_file(connection_file, profile=profile)
264 263
265 264 with open(cf) as f:
266 265 info = f.read()
267 266
268 267 if unpack:
269 268 info = json.loads(info)
270 269 # ensure key is bytes:
271 270 info['key'] = str_to_bytes(info.get('key', ''))
272 271 return info
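
Review note: typical use, unpacked (connection file name illustrative):

    from IPython.kernel.connect import get_connection_info

    info = get_connection_info('kernel-1234.json', unpack=True)
    print(info['transport'], info['ip'], info['shell_port'])
    assert isinstance(info['key'], bytes)  # normalized by str_to_bytes above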
273 272
274 273
275 274 def connect_qtconsole(connection_file=None, argv=None, profile=None):
276 275 """Connect a qtconsole to the current kernel.
277 276
278 277 This is useful for connecting a second qtconsole to a kernel, or to a
279 278 local notebook.
280 279
281 280 Parameters
282 281 ----------
283 282 connection_file : str [optional]
284 283 The connection file to be used. Can be given by absolute path, or
285 284 IPython will search in the security directory of a given profile.
286 285
287 286
288 287 If unspecified, the connection file for the currently running
289 288 IPython Kernel will be used, which is only allowed from inside a kernel.
290 289 argv : list [optional]
291 290 Any extra args to be passed to the console.
292 291 profile : str [optional]
293 292 The name of the profile to use when searching for the connection file,
294 293 if different from the current IPython session or 'default'.
295 294
296 295
297 296 Returns
298 297 -------
299 298 :class:`subprocess.Popen` instance running the qtconsole frontend
300 299 """
301 300 argv = [] if argv is None else argv
302 301
303 302 if connection_file is None:
304 303 # get connection file from current kernel
305 304 cf = get_connection_file()
306 305 else:
307 306 cf = find_connection_file(connection_file, profile=profile)
308 307
309 308 cmd = ';'.join([
310 309 "from IPython.qt.console import qtconsoleapp",
311 310 "qtconsoleapp.main()"
312 311 ])
313 312
314 313 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
315 314 stdout=PIPE, stderr=PIPE, close_fds=(sys.platform != 'win32'),
316 315 )
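
Review note: usage sketch (file name illustrative):

    from IPython.kernel.connect import connect_qtconsole

    # from inside a running kernel, no arguments are needed;
    # otherwise, point it at a connection file:
    p = connect_qtconsole('kernel-1234.json')
    print(p.pid)   # the console runs as a separate subprocess
    p.terminate()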
317 316
318 317
319 318 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
320 319 """tunnel connections to a kernel via ssh
321 320
322 321 This will open four SSH tunnels from localhost on this machine to the
323 322 ports associated with the kernel. They can be either direct
324 323 localhost-localhost tunnels, or if an intermediate server is necessary,
325 324 the kernel must be listening on a public IP.
326 325
327 326 Parameters
328 327 ----------
329 328 connection_info : dict or str (path)
330 329 Either a connection dict, or the path to a JSON connection file
331 330 sshserver : str
332 331 The ssh server to use to tunnel to the kernel. Can be a full
333 332 `user@server:port` string. ssh config aliases are respected.
334 333 sshkey : str [optional]
335 334 Path to file containing ssh key to use for authentication.
336 335 Only necessary if your ssh config does not already associate
337 336 a keyfile with the host.
338 337
339 338 Returns
340 339 -------
341 340
342 341 (shell, iopub, stdin, hb) : ints
343 342 The four ports on localhost that have been forwarded to the kernel.
344 343 """
344 from zmq.ssh import tunnel
345 345 if isinstance(connection_info, string_types):
346 346 # it's a path, unpack it
347 347 with open(connection_info) as f:
348 348 connection_info = json.loads(f.read())
349 349
350 350 cf = connection_info
351 351
352 352 lports = tunnel.select_random_ports(4)
353 353 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
354 354
355 355 remote_ip = cf['ip']
356 356
357 357 if tunnel.try_passwordless_ssh(sshserver, sshkey):
358 358 password=False
359 359 else:
360 360 password = getpass("SSH Password for %s: " % cast_bytes_py2(sshserver))
361 361
362 362 for lp,rp in zip(lports, rports):
363 363 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
364 364
365 365 return tuple(lports)
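
Review note: with `from zmq.ssh import tunnel` moved inside the function, pexpect/paramiko are only imported when tunneling is actually requested. Usage sketch (host and file names illustrative):

    from IPython.kernel.connect import tunnel_to_kernel

    shell, iopub, stdin, hb = tunnel_to_kernel(
        'kernel-1234.json', 'me@gateway.example.com')
    # the four returned local ports now forward to the remote kernel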
366 366
367 367
368 368 #-----------------------------------------------------------------------------
369 369 # Mixin for classes that work with connection files
370 370 #-----------------------------------------------------------------------------
371 371
372 372 channel_socket_types = {
373 373 'hb' : zmq.REQ,
374 374 'shell' : zmq.DEALER,
375 375 'iopub' : zmq.SUB,
376 376 'stdin' : zmq.DEALER,
377 377 'control': zmq.DEALER,
378 378 }
379 379
380 380 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
381 381
382 382 class ConnectionFileMixin(LoggingConfigurable):
383 383 """Mixin for configurable classes that work with connection files"""
384 384
385 385 # The addresses for the communication channels
386 386 connection_file = Unicode('', config=True,
387 387 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
388 388
389 389 This file will contain the IP, ports, and authentication key needed to connect
390 390 clients to this kernel. By default, this file will be created in the security dir
391 391 of the current profile, but can be specified by absolute path.
392 392 """)
393 393 _connection_file_written = Bool(False)
394 394
395 395 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
396 396
397 397 ip = Unicode(config=True,
398 398 help="""Set the kernel\'s IP address [default localhost].
399 399 If the IP address is something other than localhost, then
400 400 Consoles on other machines will be able to connect
401 401 to the Kernel, so be careful!"""
402 402 )
403 403
404 404 def _ip_default(self):
405 405 if self.transport == 'ipc':
406 406 if self.connection_file:
407 407 return os.path.splitext(self.connection_file)[0] + '-ipc'
408 408 else:
409 409 return 'kernel-ipc'
410 410 else:
411 411 return localhost()
412 412
413 413 def _ip_changed(self, name, old, new):
414 414 if new == '*':
415 415 self.ip = '0.0.0.0'
416 416
417 417 # protected traits
418 418
419 419 hb_port = Integer(0, config=True,
420 420 help="set the heartbeat port [default: random]")
421 421 shell_port = Integer(0, config=True,
422 422 help="set the shell (ROUTER) port [default: random]")
423 423 iopub_port = Integer(0, config=True,
424 424 help="set the iopub (PUB) port [default: random]")
425 425 stdin_port = Integer(0, config=True,
426 426 help="set the stdin (ROUTER) port [default: random]")
427 427 control_port = Integer(0, config=True,
428 428 help="set the control (ROUTER) port [default: random]")
429 429
430 430 @property
431 431 def ports(self):
432 432 return [ getattr(self, name) for name in port_names ]
433 433
434 434 # The Session to use for communication with the kernel.
435 435 session = Instance('IPython.kernel.zmq.session.Session')
436 436 def _session_default(self):
437 437 from IPython.kernel.zmq.session import Session
438 438 return Session(parent=self)
439 439
440 440 #--------------------------------------------------------------------------
441 441 # Connection and ipc file management
442 442 #--------------------------------------------------------------------------
443 443
444 444 def get_connection_info(self):
445 445 """return the connection info as a dict"""
446 446 return dict(
447 447 transport=self.transport,
448 448 ip=self.ip,
449 449 shell_port=self.shell_port,
450 450 iopub_port=self.iopub_port,
451 451 stdin_port=self.stdin_port,
452 452 hb_port=self.hb_port,
453 453 control_port=self.control_port,
454 454 signature_scheme=self.session.signature_scheme,
455 455 key=self.session.key,
456 456 )
457 457
458 458 def cleanup_connection_file(self):
459 459 """Cleanup connection file *if we wrote it*
460 460
461 461 Will not raise if the connection file was already removed somehow.
462 462 """
463 463 if self._connection_file_written:
464 464 # cleanup connection files on full shutdown of kernel we started
465 465 self._connection_file_written = False
466 466 try:
467 467 os.remove(self.connection_file)
468 468 except (IOError, OSError, AttributeError):
469 469 pass
470 470
471 471 def cleanup_ipc_files(self):
472 472 """Cleanup ipc files if we wrote them."""
473 473 if self.transport != 'ipc':
474 474 return
475 475 for port in self.ports:
476 476 ipcfile = "%s-%i" % (self.ip, port)
477 477 try:
478 478 os.remove(ipcfile)
479 479 except (IOError, OSError):
480 480 pass
481 481
482 482 def write_connection_file(self):
483 483 """Write connection info to JSON dict in self.connection_file."""
484 484 if self._connection_file_written and os.path.exists(self.connection_file):
485 485 return
486 486
487 487 self.connection_file, cfg = write_connection_file(self.connection_file,
488 488 transport=self.transport, ip=self.ip, key=self.session.key,
489 489 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
490 490 shell_port=self.shell_port, hb_port=self.hb_port,
491 491 control_port=self.control_port,
492 492 signature_scheme=self.session.signature_scheme,
493 493 )
494 494 # write_connection_file also sets default ports:
495 495 for name in port_names:
496 496 setattr(self, name, cfg[name])
497 497
498 498 self._connection_file_written = True
499 499
500 500 def load_connection_file(self):
501 501 """Load connection info from JSON dict in self.connection_file."""
502 502 self.log.debug(u"Loading connection file %s", self.connection_file)
503 503 with open(self.connection_file) as f:
504 504 cfg = json.load(f)
505 505 self.transport = cfg.get('transport', self.transport)
506 506 self.ip = cfg.get('ip', self._ip_default())
507 507
508 508 for name in port_names:
509 509 if getattr(self, name) == 0 and name in cfg:
510 510 # not overridden by config or cl_args
511 511 setattr(self, name, cfg[name])
512 512
513 513 if 'key' in cfg:
514 514 self.session.key = str_to_bytes(cfg['key'])
515 515 if 'signature_scheme' in cfg:
516 516 self.session.signature_scheme = cfg['signature_scheme']
517 517
518 518 #--------------------------------------------------------------------------
519 519 # Creating connected sockets
520 520 #--------------------------------------------------------------------------
521 521
522 522 def _make_url(self, channel):
523 523 """Make a ZeroMQ URL for a given channel."""
524 524 transport = self.transport
525 525 ip = self.ip
526 526 port = getattr(self, '%s_port' % channel)
527 527
528 528 if transport == 'tcp':
529 529 return "tcp://%s:%i" % (ip, port)
530 530 else:
531 531 return "%s://%s-%s" % (transport, ip, port)
532 532
533 533 def _create_connected_socket(self, channel, identity=None):
534 534 """Create a zmq Socket and connect it to the kernel."""
535 535 url = self._make_url(channel)
536 536 socket_type = channel_socket_types[channel]
537 537 self.log.debug("Connecting to: %s" % url)
538 538 sock = self.context.socket(socket_type)
539 539 # set linger to 1s to prevent hangs at exit
540 540 sock.linger = 1000
541 541 if identity:
542 542 sock.identity = identity
543 543 sock.connect(url)
544 544 return sock
545 545
546 546 def connect_iopub(self, identity=None):
547 547 """return zmq Socket connected to the IOPub channel"""
548 548 sock = self._create_connected_socket('iopub', identity=identity)
549 549 sock.setsockopt(zmq.SUBSCRIBE, b'')
550 550 return sock
551 551
552 552 def connect_shell(self, identity=None):
553 553 """return zmq Socket connected to the Shell channel"""
554 554 return self._create_connected_socket('shell', identity=identity)
555 555
556 556 def connect_stdin(self, identity=None):
557 557 """return zmq Socket connected to the StdIn channel"""
558 558 return self._create_connected_socket('stdin', identity=identity)
559 559
560 560 def connect_hb(self, identity=None):
561 561 """return zmq Socket connected to the Heartbeat channel"""
562 562 return self._create_connected_socket('hb', identity=identity)
563 563
564 564 def connect_control(self, identity=None):
565 565 """return zmq Socket connected to the Heartbeat channel"""
566 566 return self._create_connected_socket('control', identity=identity)
567 567
568 568
569 569 __all__ = [
570 570 'write_connection_file',
571 571 'get_connection_file',
572 572 'find_connection_file',
573 573 'get_connection_info',
574 574 'connect_qtconsole',
575 575 'tunnel_to_kernel',
576 576 ]
@@ -1,1868 +1,1870 b''
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 from zmq.ssh import tunnel
22 21
23 22 from IPython.config.configurable import MultipleInstanceError
24 23 from IPython.core.application import BaseIPythonApplication
25 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
26 25
27 26 from IPython.utils.capture import RichOutput
28 27 from IPython.utils.coloransi import TermColors
29 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
30 29 from IPython.utils.localinterfaces import localhost, is_local_ip
31 30 from IPython.utils.path import get_ipython_dir
32 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
33 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
34 33 Dict, List, Bool, Set, Any)
35 34 from IPython.external.decorator import decorator
36 35
37 36 from IPython.parallel import Reference
38 37 from IPython.parallel import error
39 38 from IPython.parallel import util
40 39
41 40 from IPython.kernel.zmq.session import Session, Message
42 41 from IPython.kernel.zmq import serialize
43 42
44 43 from .asyncresult import AsyncResult, AsyncHubResult
45 44 from .view import DirectView, LoadBalancedView
46 45
47 46 #--------------------------------------------------------------------------
48 47 # Decorators for Client methods
49 48 #--------------------------------------------------------------------------
50 49
51 50 @decorator
52 51 def spin_first(f, self, *args, **kwargs):
53 52 """Call spin() to sync state prior to calling the method."""
54 53 self.spin()
55 54 return f(self, *args, **kwargs)
56 55
57 56
58 57 #--------------------------------------------------------------------------
59 58 # Classes
60 59 #--------------------------------------------------------------------------
61 60
62 61
63 62 class ExecuteReply(RichOutput):
64 63 """wrapper for finished Execute results"""
65 64 def __init__(self, msg_id, content, metadata):
66 65 self.msg_id = msg_id
67 66 self._content = content
68 67 self.execution_count = content['execution_count']
69 68 self.metadata = metadata
70 69
71 70 # RichOutput overrides
72 71
73 72 @property
74 73 def source(self):
75 74 execute_result = self.metadata['execute_result']
76 75 if execute_result:
77 76 return execute_result.get('source', '')
78 77
79 78 @property
80 79 def data(self):
81 80 execute_result = self.metadata['execute_result']
82 81 if execute_result:
83 82 return execute_result.get('data', {})
84 83
85 84 @property
86 85 def _metadata(self):
87 86 execute_result = self.metadata['execute_result']
88 87 if execute_result:
89 88 return execute_result.get('metadata', {})
90 89
91 90 def display(self):
92 91 from IPython.display import publish_display_data
93 92 publish_display_data(self.data, self.metadata)
94 93
95 94 def _repr_mime_(self, mime):
96 95 if mime not in self.data:
97 96 return
98 97 data = self.data[mime]
99 98 if mime in self._metadata:
100 99 return data, self._metadata[mime]
101 100 else:
102 101 return data
103 102
104 103 def __getitem__(self, key):
105 104 return self.metadata[key]
106 105
107 106 def __getattr__(self, key):
108 107 if key not in self.metadata:
109 108 raise AttributeError(key)
110 109 return self.metadata[key]
111 110
112 111 def __repr__(self):
113 112 execute_result = self.metadata['execute_result'] or {'data':{}}
114 113 text_out = execute_result['data'].get('text/plain', '')
115 114 if len(text_out) > 32:
116 115 text_out = text_out[:29] + '...'
117 116
118 117 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
119 118
120 119 def _repr_pretty_(self, p, cycle):
121 120 execute_result = self.metadata['execute_result'] or {'data':{}}
122 121 text_out = execute_result['data'].get('text/plain', '')
123 122
124 123 if not text_out:
125 124 return
126 125
127 126 try:
128 127 ip = get_ipython()
129 128 except NameError:
130 129 colors = "NoColor"
131 130 else:
132 131 colors = ip.colors
133 132
134 133 if colors == "NoColor":
135 134 out = normal = ""
136 135 else:
137 136 out = TermColors.Red
138 137 normal = TermColors.Normal
139 138
140 139 if '\n' in text_out and not text_out.startswith('\n'):
141 140 # add newline for multiline reprs
142 141 text_out = '\n' + text_out
143 142
144 143 p.text(
145 144 out + u'Out[%i:%i]: ' % (
146 145 self.metadata['engine_id'], self.execution_count
147 146 ) + normal + text_out
148 147 )
149 148
150 149
151 150 class Metadata(dict):
152 151 """Subclass of dict for initializing metadata values.
153 152
154 153 Attribute access works on keys.
155 154
156 155 These objects have a strict set of keys - errors will raise if you try
157 156 to add new keys.
158 157 """
159 158 def __init__(self, *args, **kwargs):
160 159 dict.__init__(self)
161 160 md = {'msg_id' : None,
162 161 'submitted' : None,
163 162 'started' : None,
164 163 'completed' : None,
165 164 'received' : None,
166 165 'engine_uuid' : None,
167 166 'engine_id' : None,
168 167 'follow' : None,
169 168 'after' : None,
170 169 'status' : None,
171 170
172 171 'execute_input' : None,
173 172 'execute_result' : None,
174 173 'error' : None,
175 174 'stdout' : '',
176 175 'stderr' : '',
177 176 'outputs' : [],
178 177 'data': {},
179 178 'outputs_ready' : False,
180 179 }
181 180 self.update(md)
182 181 self.update(dict(*args, **kwargs))
183 182
184 183 def __getattr__(self, key):
185 184 """getattr aliased to getitem"""
186 185 if key in self:
187 186 return self[key]
188 187 else:
189 188 raise AttributeError(key)
190 189
191 190 def __setattr__(self, key, value):
192 191 """setattr aliased to setitem, with strict"""
193 192 if key in self:
194 193 self[key] = value
195 194 else:
196 195 raise AttributeError(key)
197 196
198 197 def __setitem__(self, key, value):
199 198 """strict static key enforcement"""
200 199 if key in self:
201 200 dict.__setitem__(self, key, value)
202 201 else:
203 202 raise KeyError(key)
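
Review note: the strictness in practice (a minimal sketch, assuming `Metadata` is imported from this module):

    md = Metadata(engine_id=0)
    assert md.status is None   # attribute access reads the known keys
    md['stdout'] = 'hi'        # known keys are writable
    try:
        md.surprise = 1        # unknown keys raise, per the docstring
    except AttributeError:
        pass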
204 203
205 204
206 205 class Client(HasTraits):
207 206 """A semi-synchronous client to the IPython ZMQ cluster
208 207
209 208 Parameters
210 209 ----------
211 210
212 211 url_file : str/unicode; path to ipcontroller-client.json
213 212 This JSON file should contain all the information needed to connect to a cluster,
214 213 and is likely the only argument needed: it carries the connection
215 214 information for the Hub's registration. If such a connector file is
216 215 given, then likely no further configuration is necessary.
217 216 [Default: use profile]
218 217 profile : bytes
219 218 The name of the Cluster profile to be used to find connector information.
220 219 If run from an IPython application, the default profile will be the same
221 220 as the running application, otherwise it will be 'default'.
222 221 cluster_id : str
223 222 String id to be added to runtime files, to prevent name collisions when using
224 223 multiple clusters with a single profile simultaneously.
225 224 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
226 225 Since this is text inserted into filenames, typical recommendations apply:
227 226 Simple character strings are ideal, and spaces are not recommended (but
228 227 should generally work)
229 228 context : zmq.Context
230 229 Pass an existing zmq.Context instance, otherwise the client will create its own.
231 230 debug : bool
232 231 flag for lots of message printing for debug purposes
233 232 timeout : int/float
234 233 time (in seconds) to wait for connection replies from the Hub
235 234 [Default: 10]
236 235
237 236 #-------------- session related args ----------------
238 237
239 238 config : Config object
240 239 If specified, this will be relayed to the Session for configuration
241 240 username : str
242 241 set username for the session object
243 242
244 243 #-------------- ssh related args ----------------
245 244 # These are args for configuring the ssh tunnel to be used
246 245 # credentials are used to forward connections over ssh to the Controller
247 246 # Note that the ip given in `addr` needs to be relative to sshserver
248 247 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
249 248 # and set sshserver as the same machine the Controller is on. However,
250 249 # the only requirement is that sshserver is able to see the Controller
251 250 # (i.e. is within the same trusted network).
252 251
253 252 sshserver : str
254 253 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
255 254 If keyfile or password is specified but sshserver is not, it will default to
256 255 the ip given in addr.
257 256 sshkey : str; path to ssh private key file
258 257 This specifies a key to be used in ssh login, default None.
259 258 Regular default ssh keys will be used without specifying this argument.
260 259 password : str
261 260 Your ssh password to sshserver. Note that if this is left None,
262 261 you will be prompted for it if passwordless key based login is unavailable.
263 262 paramiko : bool
264 263 flag for whether to use paramiko instead of shell ssh for tunneling.
265 264 [default: True on win32, False otherwise]
266 265
267 266
268 267 Attributes
269 268 ----------
270 269
271 270 ids : list of int engine IDs
272 271 requesting the ids attribute always synchronizes
273 272 the registration state. To request ids without synchronization,
274 273 use the semi-private _ids attribute.
275 274
276 275 history : list of msg_ids
277 276 a list of msg_ids, keeping track of all the execution
278 277 messages you have submitted in order.
279 278
280 279 outstanding : set of msg_ids
281 280 a set of msg_ids that have been submitted, but whose
282 281 results have not yet been received.
283 282
284 283 results : dict
285 284 a dict of all our results, keyed by msg_id
286 285
287 286 block : bool
288 287 determines default behavior when block not specified
289 288 in execution methods
290 289
291 290 Methods
292 291 -------
293 292
294 293 spin
295 294 flushes incoming results and registration state changes
297 296 control methods spin first, and requesting `ids` also keeps the state up to date
297 296
298 297 wait
299 298 wait on one or more msg_ids
300 299
301 300 execution methods
302 301 apply
303 302 legacy: execute, run
304 303
305 304 data movement
306 305 push, pull, scatter, gather
307 306
308 307 query methods
309 308 queue_status, get_result, purge, result_status
310 309
311 310 control methods
312 311 abort, shutdown
313 312
314 313 """
315 314
316 315
317 316 block = Bool(False)
318 317 outstanding = Set()
319 318 results = Instance('collections.defaultdict', (dict,))
320 319 metadata = Instance('collections.defaultdict', (Metadata,))
321 320 history = List()
322 321 debug = Bool(False)
323 322 _spin_thread = Any()
324 323 _stop_spinning = Any()
325 324
326 325 profile=Unicode()
327 326 def _profile_default(self):
328 327 if BaseIPythonApplication.initialized():
329 328 # an IPython app *might* be running, try to get its profile
330 329 try:
331 330 return BaseIPythonApplication.instance().profile
332 331 except (AttributeError, MultipleInstanceError):
333 332 # could be a *different* subclass of config.Application,
334 333 # which would raise one of these two errors.
335 334 return u'default'
336 335 else:
337 336 return u'default'
338 337
339 338
340 339 _outstanding_dict = Instance('collections.defaultdict', (set,))
341 340 _ids = List()
342 341 _connected=Bool(False)
343 342 _ssh=Bool(False)
344 343 _context = Instance('zmq.Context')
345 344 _config = Dict()
346 345 _engines=Instance(util.ReverseDict, (), {})
347 346 # _hub_socket=Instance('zmq.Socket')
348 347 _query_socket=Instance('zmq.Socket')
349 348 _control_socket=Instance('zmq.Socket')
350 349 _iopub_socket=Instance('zmq.Socket')
351 350 _notification_socket=Instance('zmq.Socket')
352 351 _mux_socket=Instance('zmq.Socket')
353 352 _task_socket=Instance('zmq.Socket')
354 353 _task_scheme=Unicode()
355 354 _closed = False
356 355 _ignored_control_replies=Integer(0)
357 356 _ignored_hub_replies=Integer(0)
358 357
359 358 def __new__(self, *args, **kw):
360 359 # don't raise on positional args
361 360 return HasTraits.__new__(self, **kw)
362 361
363 362 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
364 363 context=None, debug=False,
365 364 sshserver=None, sshkey=None, password=None, paramiko=None,
366 365 timeout=10, cluster_id=None, **extra_args
367 366 ):
368 367 if profile:
369 368 super(Client, self).__init__(debug=debug, profile=profile)
370 369 else:
371 370 super(Client, self).__init__(debug=debug)
372 371 if context is None:
373 372 context = zmq.Context.instance()
374 373 self._context = context
375 374 self._stop_spinning = Event()
376 375
377 376 if 'url_or_file' in extra_args:
378 377 url_file = extra_args['url_or_file']
379 378 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
380 379
381 380 if url_file and util.is_url(url_file):
382 381 raise ValueError("single urls cannot be specified, url-files must be used.")
383 382
384 383 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
385 384
386 385 if self._cd is not None:
387 386 if url_file is None:
388 387 if not cluster_id:
389 388 client_json = 'ipcontroller-client.json'
390 389 else:
391 390 client_json = 'ipcontroller-%s-client.json' % cluster_id
392 391 url_file = pjoin(self._cd.security_dir, client_json)
393 392 if url_file is None:
394 393 raise ValueError(
395 394 "I can't find enough information to connect to a hub!"
396 395 " Please specify at least one of url_file or profile."
397 396 )
398 397
399 398 with open(url_file) as f:
400 399 cfg = json.load(f)
401 400
402 401 self._task_scheme = cfg['task_scheme']
403 402
404 403 # sync defaults from args, json:
405 404 if sshserver:
406 405 cfg['ssh'] = sshserver
407 406
408 407 location = cfg.setdefault('location', None)
409 408
410 409 proto,addr = cfg['interface'].split('://')
411 410 addr = util.disambiguate_ip_address(addr, location)
412 411 cfg['interface'] = "%s://%s" % (proto, addr)
413 412
414 413 # turn interface,port into full urls:
415 414 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
416 415 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
417 416
418 417 url = cfg['registration']
419 418
420 419 if location is not None and addr == localhost():
421 420 # location specified, and connection is expected to be local
422 421 if not is_local_ip(location) and not sshserver:
423 422 # load ssh from JSON *only* if the controller is not on
424 423 # this machine
425 424 sshserver=cfg['ssh']
426 425 if not is_local_ip(location) and not sshserver:
427 426 # warn if no ssh specified, but SSH is probably needed
428 427 # This is only a warning, because the most likely cause
429 428 # is a local Controller on a laptop whose IP is dynamic
430 429 warnings.warn("""
431 430 Controller appears to be listening on localhost, but not on this machine.
432 431 If this is true, you should specify Client(...,sshserver='you@%s')
433 432 or instruct your controller to listen on an external IP."""%location,
434 433 RuntimeWarning)
435 434 elif not sshserver:
436 435 # otherwise sync with cfg
437 436 sshserver = cfg['ssh']
438 437
439 438 self._config = cfg
440 439
441 440 self._ssh = bool(sshserver or sshkey or password)
442 441 if self._ssh and sshserver is None:
443 442 # default to ssh via localhost
444 443 sshserver = addr
445 444 if self._ssh and password is None:
445 from zmq.ssh import tunnel
446 446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
447 447 password=False
448 448 else:
449 449 password = getpass("SSH Password for %s: "%sshserver)
450 450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
451 451
452 452 # configure and construct the session
453 453 try:
454 454 extra_args['packer'] = cfg['pack']
455 455 extra_args['unpacker'] = cfg['unpack']
456 456 extra_args['key'] = cast_bytes(cfg['key'])
457 457 extra_args['signature_scheme'] = cfg['signature_scheme']
458 458 except KeyError as exc:
459 459 msg = '\n'.join([
460 460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
461 461 "If you are reusing connection files, remove them and start ipcontroller again."
462 462 ])
463 463 raise ValueError(msg.format(exc.message))
464 464
465 465 self.session = Session(**extra_args)
466 466
467 467 self._query_socket = self._context.socket(zmq.DEALER)
468 468
469 469 if self._ssh:
470 from zmq.ssh import tunnel
470 471 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
471 472 else:
472 473 self._query_socket.connect(cfg['registration'])
473 474
474 475 self.session.debug = self.debug
475 476
476 477 self._notification_handlers = {'registration_notification' : self._register_engine,
477 478 'unregistration_notification' : self._unregister_engine,
478 479 'shutdown_notification' : lambda msg: self.close(),
479 480 }
480 481 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
481 482 'apply_reply' : self._handle_apply_reply}
482 483
483 484 try:
484 485 self._connect(sshserver, ssh_kwargs, timeout)
485 486 except:
486 487 self.close(linger=0)
487 488 raise
488 489
489 490 # last step: setup magics, if we are in IPython:
490 491
491 492 try:
492 493 ip = get_ipython()
493 494 except NameError:
494 495 return
495 496 else:
496 497 if 'px' not in ip.magics_manager.magics:
497 498 # in IPython but we are the first Client.
498 499 # activate a default view for parallel magics.
499 500 self.activate()
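
Review note: the net effect of the import moves in this class is that constructing a plain local Client never touches zmq.ssh. A connection sketch (paths and hosts illustrative):

    from IPython.parallel import Client

    rc = Client(profile='default')  # reads ipcontroller-client.json for that profile
    # SSH variant; only this path imports zmq.ssh:
    # rc = Client(url_file='ipcontroller-client.json',
    #             sshserver='me@gateway.example.com')
    print(rc.ids)                   # currently registered engine ids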
500 501
501 502 def __del__(self):
502 503 """cleanup sockets, but _not_ context."""
503 504 self.close()
504 505
505 506 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
506 507 if ipython_dir is None:
507 508 ipython_dir = get_ipython_dir()
508 509 if profile_dir is not None:
509 510 try:
510 511 self._cd = ProfileDir.find_profile_dir(profile_dir)
511 512 return
512 513 except ProfileDirError:
513 514 pass
514 515 elif profile is not None:
515 516 try:
516 517 self._cd = ProfileDir.find_profile_dir_by_name(
517 518 ipython_dir, profile)
518 519 return
519 520 except ProfileDirError:
520 521 pass
521 522 self._cd = None
522 523
523 524 def _update_engines(self, engines):
524 525 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
525 526 for k,v in iteritems(engines):
526 527 eid = int(k)
527 528 if eid not in self._engines:
528 529 self._ids.append(eid)
529 530 self._engines[eid] = v
530 531 self._ids = sorted(self._ids)
531 532 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
532 533 self._task_scheme == 'pure' and self._task_socket:
533 534 self._stop_scheduling_tasks()
534 535
535 536 def _stop_scheduling_tasks(self):
536 537 """Stop scheduling tasks because an engine has been unregistered
537 538 from a pure ZMQ scheduler.
538 539 """
539 540 self._task_socket.close()
540 541 self._task_socket = None
541 542 msg = "An engine has been unregistered, and we are using pure " +\
542 543 "ZMQ task scheduling. Task farming will be disabled."
543 544 if self.outstanding:
544 545 msg += " If you were running tasks when this happened, " +\
545 546 "some `outstanding` msg_ids may never resolve."
546 547 warnings.warn(msg, RuntimeWarning)
547 548
548 549 def _build_targets(self, targets):
549 550 """Turn valid target IDs or 'all' into two lists:
550 551 (int_ids, uuids).
551 552 """
552 553 if not self._ids:
553 554 # flush notification socket if no engines yet, just in case
554 555 if not self.ids:
555 556 raise error.NoEnginesRegistered("Can't build targets without any engines")
556 557
557 558 if targets is None:
558 559 targets = self._ids
559 560 elif isinstance(targets, string_types):
560 561 if targets.lower() == 'all':
561 562 targets = self._ids
562 563 else:
563 564 raise TypeError("%r not valid str target, must be 'all'"%(targets))
564 565 elif isinstance(targets, int):
565 566 if targets < 0:
566 567 targets = self.ids[targets]
567 568 if targets not in self._ids:
568 569 raise IndexError("No such engine: %i"%targets)
569 570 targets = [targets]
570 571
571 572 if isinstance(targets, slice):
572 573 indices = list(range(len(self._ids))[targets])
573 574 ids = self.ids
574 575 targets = [ ids[i] for i in indices ]
575 576
576 577 if not isinstance(targets, (tuple, list, xrange)):
577 578 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
578 579
579 580 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
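
Review note: the accepted `targets` spellings, continuing the `rc` sketch above (a private helper, shown only to document the normalization; each call returns the (uuids, int_ids) pair):

    rc._build_targets('all')           # every registered engine
    rc._build_targets(3)               # a single engine id
    rc._build_targets(-1)              # negative ints index the sorted id list
    rc._build_targets(slice(0, 4, 2))  # slices of the id list
    rc._build_targets([0, 2])          # any collection of ints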
580 581
581 582 def _connect(self, sshserver, ssh_kwargs, timeout):
582 583 """setup all our socket connections to the cluster. This is called from
583 584 __init__."""
584 585
585 586 # Maybe allow reconnecting?
586 587 if self._connected:
587 588 return
588 589 self._connected=True
589 590
590 591 def connect_socket(s, url):
591 592 if self._ssh:
593 from zmq.ssh import tunnel
592 594 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
593 595 else:
594 596 return s.connect(url)
595 597
596 598 self.session.send(self._query_socket, 'connection_request')
597 599 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
598 600 poller = zmq.Poller()
599 601 poller.register(self._query_socket, zmq.POLLIN)
600 602 # poll expects milliseconds, timeout is seconds
601 603 evts = poller.poll(timeout*1000)
602 604 if not evts:
603 605 raise error.TimeoutError("Hub connection request timed out")
604 606 idents,msg = self.session.recv(self._query_socket,mode=0)
605 607 if self.debug:
606 608 pprint(msg)
607 609 content = msg['content']
608 610 # self._config['registration'] = dict(content)
609 611 cfg = self._config
610 612 if content['status'] == 'ok':
611 613 self._mux_socket = self._context.socket(zmq.DEALER)
612 614 connect_socket(self._mux_socket, cfg['mux'])
613 615
614 616 self._task_socket = self._context.socket(zmq.DEALER)
615 617 connect_socket(self._task_socket, cfg['task'])
616 618
617 619 self._notification_socket = self._context.socket(zmq.SUB)
618 620 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
619 621 connect_socket(self._notification_socket, cfg['notification'])
620 622
621 623 self._control_socket = self._context.socket(zmq.DEALER)
622 624 connect_socket(self._control_socket, cfg['control'])
623 625
624 626 self._iopub_socket = self._context.socket(zmq.SUB)
625 627 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
626 628 connect_socket(self._iopub_socket, cfg['iopub'])
627 629
628 630 self._update_engines(dict(content['engines']))
629 631 else:
630 632 self._connected = False
631 633 raise Exception("Failed to connect!")
632 634
633 635 #--------------------------------------------------------------------------
634 636 # handlers and callbacks for incoming messages
635 637 #--------------------------------------------------------------------------
636 638
637 639 def _unwrap_exception(self, content):
638 640 """unwrap exception, and remap engine_id to int."""
639 641 e = error.unwrap_exception(content)
640 642 # print e.traceback
641 643 if e.engine_info:
642 644 e_uuid = e.engine_info['engine_uuid']
643 645 eid = self._engines[e_uuid]
644 646 e.engine_info['engine_id'] = eid
645 647 return e
646 648
647 649 def _extract_metadata(self, msg):
648 650 header = msg['header']
649 651 parent = msg['parent_header']
650 652 msg_meta = msg['metadata']
651 653 content = msg['content']
652 654 md = {'msg_id' : parent['msg_id'],
653 655 'received' : datetime.now(),
654 656 'engine_uuid' : msg_meta.get('engine', None),
655 657 'follow' : msg_meta.get('follow', []),
656 658 'after' : msg_meta.get('after', []),
657 659 'status' : content['status'],
658 660 }
659 661
660 662 if md['engine_uuid'] is not None:
661 663 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
662 664
663 665 if 'date' in parent:
664 666 md['submitted'] = parent['date']
665 667 if 'started' in msg_meta:
666 668 md['started'] = parse_date(msg_meta['started'])
667 669 if 'date' in header:
668 670 md['completed'] = header['date']
669 671 return md
670 672
671 673 def _register_engine(self, msg):
672 674 """Register a new engine, and update our connection info."""
673 675 content = msg['content']
674 676 eid = content['id']
675 677 d = {eid : content['uuid']}
676 678 self._update_engines(d)
677 679
678 680 def _unregister_engine(self, msg):
679 681 """Unregister an engine that has died."""
680 682 content = msg['content']
681 683 eid = int(content['id'])
682 684 if eid in self._ids:
683 685 self._ids.remove(eid)
684 686 uuid = self._engines.pop(eid)
685 687
686 688 self._handle_stranded_msgs(eid, uuid)
687 689
688 690 if self._task_socket and self._task_scheme == 'pure':
689 691 self._stop_scheduling_tasks()
690 692
691 693 def _handle_stranded_msgs(self, eid, uuid):
692 694 """Handle messages known to be on an engine when the engine unregisters.
693 695
694 696 It is possible that this will fire prematurely - that is, an engine will
695 697 go down after completing a result, and the client will be notified
696 698 of the unregistration and later receive the successful result.
697 699 """
698 700
699 701 outstanding = self._outstanding_dict[uuid]
700 702
701 703 for msg_id in list(outstanding):
702 704 if msg_id in self.results:
703 705 # we already have this result
704 706 continue
705 707 try:
706 708 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
707 709 except:
708 710 content = error.wrap_exception()
709 711 # build a fake message:
710 712 msg = self.session.msg('apply_reply', content=content)
711 713 msg['parent_header']['msg_id'] = msg_id
712 714 msg['metadata']['engine'] = uuid
713 715 self._handle_apply_reply(msg)
714 716
715 717 def _handle_execute_reply(self, msg):
716 718 """Save the reply to an execute_request into our results.
717 719
718 720 execute messages are never actually used. apply is used instead.
719 721 """
720 722
721 723 parent = msg['parent_header']
722 724 msg_id = parent['msg_id']
723 725 if msg_id not in self.outstanding:
724 726 if msg_id in self.history:
725 727 print("got stale result: %s"%msg_id)
726 728 else:
727 729 print("got unknown result: %s"%msg_id)
728 730 else:
729 731 self.outstanding.remove(msg_id)
730 732
731 733 content = msg['content']
732 734 header = msg['header']
733 735
734 736 # construct metadata:
735 737 md = self.metadata[msg_id]
736 738 md.update(self._extract_metadata(msg))
737 739 # is this redundant?
738 740 self.metadata[msg_id] = md
739 741
740 742 e_outstanding = self._outstanding_dict[md['engine_uuid']]
741 743 if msg_id in e_outstanding:
742 744 e_outstanding.remove(msg_id)
743 745
744 746 # construct result:
745 747 if content['status'] == 'ok':
746 748 self.results[msg_id] = ExecuteReply(msg_id, content, md)
747 749 elif content['status'] == 'aborted':
748 750 self.results[msg_id] = error.TaskAborted(msg_id)
749 751 elif content['status'] == 'resubmitted':
750 752 # TODO: handle resubmission
751 753 pass
752 754 else:
753 755 self.results[msg_id] = self._unwrap_exception(content)
754 756
755 757 def _handle_apply_reply(self, msg):
756 758 """Save the reply to an apply_request into our results."""
757 759 parent = msg['parent_header']
758 760 msg_id = parent['msg_id']
759 761 if msg_id not in self.outstanding:
760 762 if msg_id in self.history:
761 763 print("got stale result: %s"%msg_id)
762 764 print(self.results[msg_id])
763 765 print(msg)
764 766 else:
765 767 print("got unknown result: %s"%msg_id)
766 768 else:
767 769 self.outstanding.remove(msg_id)
768 770 content = msg['content']
769 771 header = msg['header']
770 772
771 773 # construct metadata:
772 774 md = self.metadata[msg_id]
773 775 md.update(self._extract_metadata(msg))
774 776 # is this redundant?
775 777 self.metadata[msg_id] = md
776 778
777 779 e_outstanding = self._outstanding_dict[md['engine_uuid']]
778 780 if msg_id in e_outstanding:
779 781 e_outstanding.remove(msg_id)
780 782
781 783 # construct result:
782 784 if content['status'] == 'ok':
783 785 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
784 786 elif content['status'] == 'aborted':
785 787 self.results[msg_id] = error.TaskAborted(msg_id)
786 788 elif content['status'] == 'resubmitted':
787 789 # TODO: handle resubmission
788 790 pass
789 791 else:
790 792 self.results[msg_id] = self._unwrap_exception(content)
791 793
792 794 def _flush_notifications(self):
793 795 """Flush notifications of engine registrations waiting
794 796 in ZMQ queue."""
795 797 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
796 798 while msg is not None:
797 799 if self.debug:
798 800 pprint(msg)
799 801 msg_type = msg['header']['msg_type']
800 802 handler = self._notification_handlers.get(msg_type, None)
801 803 if handler is None:
802 804 raise Exception("Unhandled message type: %s" % msg_type)
803 805 else:
804 806 handler(msg)
805 807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
806 808
807 809 def _flush_results(self, sock):
808 810 """Flush task or queue results waiting in ZMQ queue."""
809 811 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
810 812 while msg is not None:
811 813 if self.debug:
812 814 pprint(msg)
813 815 msg_type = msg['header']['msg_type']
814 816 handler = self._queue_handlers.get(msg_type, None)
815 817 if handler is None:
816 818 raise Exception("Unhandled message type: %s" % msg_type)
817 819 else:
818 820 handler(msg)
819 821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
820 822
821 823 def _flush_control(self, sock):
822 824 """Flush replies from the control channel waiting
823 825 in the ZMQ queue.
824 826
825 827 Currently: ignore them."""
826 828 if self._ignored_control_replies <= 0:
827 829 return
828 830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
829 831 while msg is not None:
830 832 self._ignored_control_replies -= 1
831 833 if self.debug:
832 834 pprint(msg)
833 835 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 836
835 837 def _flush_ignored_control(self):
836 838 """flush ignored control replies"""
837 839 while self._ignored_control_replies > 0:
838 840 self.session.recv(self._control_socket)
839 841 self._ignored_control_replies -= 1
840 842
841 843 def _flush_ignored_hub_replies(self):
842 844 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843 845 while msg is not None:
844 846 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
845 847
846 848 def _flush_iopub(self, sock):
847 849 """Flush replies from the iopub channel waiting
848 850 in the ZMQ queue.
849 851 """
850 852 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
851 853 while msg is not None:
852 854 if self.debug:
853 855 pprint(msg)
854 856 parent = msg['parent_header']
855 857 # ignore IOPub messages with no parent.
856 858 # Caused by print statements or warnings from before the first execution.
857 859 if not parent:
858 860 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
859 861 continue
860 862 msg_id = parent['msg_id']
861 863 content = msg['content']
862 864 header = msg['header']
863 865 msg_type = msg['header']['msg_type']
864 866
865 867 # init metadata:
866 868 md = self.metadata[msg_id]
867 869
868 870 if msg_type == 'stream':
869 871 name = content['name']
870 872 s = md[name] or ''
871 873 md[name] = s + content['data']
872 874 elif msg_type == 'error':
873 875 md.update({'error' : self._unwrap_exception(content)})
874 876 elif msg_type == 'execute_input':
875 877 md.update({'execute_input' : content['code']})
876 878 elif msg_type == 'display_data':
877 879 md['outputs'].append(content)
878 880 elif msg_type == 'execute_result':
879 881 md['execute_result'] = content
880 882 elif msg_type == 'data_message':
881 883 data, remainder = serialize.unserialize_object(msg['buffers'])
882 884 md['data'].update(data)
883 885 elif msg_type == 'status':
884 886 # idle message comes after all outputs
885 887 if content['execution_state'] == 'idle':
886 888 md['outputs_ready'] = True
887 889 else:
888 890 # unhandled msg_type
889 891 pass
890 892
891 893 # redundant?
892 894 self.metadata[msg_id] = md
893 895
894 896 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
895 897
896 898 #--------------------------------------------------------------------------
897 899 # len, getitem
898 900 #--------------------------------------------------------------------------
899 901
900 902 def __len__(self):
901 903 """len(client) returns # of engines."""
902 904 return len(self.ids)
903 905
904 906 def __getitem__(self, key):
905 907 """index access returns DirectView multiplexer objects
906 908
907 909 Must be int, slice, or list/tuple/xrange of ints"""
908 910 if not isinstance(key, (int, slice, tuple, list, xrange)):
909 911 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
910 912 else:
911 913 return self.direct_view(key)
912 914
913 915 def __iter__(self):
914 916 """Since we define getitem, Client is iterable
915 917
916 918 but without also defining __iter__, iteration would only work correctly if engine
917 919 IDs start at zero and are contiguous.
918 920 """
919 921 for eid in self.ids:
920 922 yield self.direct_view(eid)
921 923
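A minimal usage sketch of the container protocol above (illustrative only; it assumes a running cluster, e.g. started with `ipcluster start`):

    from IPython.parallel import Client

    rc = Client()              # connect to the default cluster
    print(len(rc))             # number of currently-registered engines
    dv0 = rc[0]                # DirectView on engine 0
    dv_half = rc[::2]          # DirectView on every other engine
    views = [dv for dv in rc]  # one single-engine DirectView per engine id
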
922 924 #--------------------------------------------------------------------------
923 925 # Begin public methods
924 926 #--------------------------------------------------------------------------
925 927
926 928 @property
927 929 def ids(self):
928 930 """Always up-to-date ids property."""
929 931 self._flush_notifications()
930 932 # always copy:
931 933 return list(self._ids)
932 934
933 935 def activate(self, targets='all', suffix=''):
934 936 """Create a DirectView and register it with IPython magics
935 937
936 938 Defines the magics `%px, %autopx, %pxresult, %%px`
937 939
938 940 Parameters
939 941 ----------
940 942
941 943 targets: int, list of ints, or 'all'
942 944 The engines on which the view's magics will run
943 945 suffix: str [default: '']
944 946 The suffix, if any, for the magics. This allows you to have
945 947 multiple views associated with parallel magics at the same time.
946 948
947 949 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
948 950 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
949 951 on engine 0.
950 952 """
951 953 view = self.direct_view(targets)
952 954 view.block = True
953 955 view.activate(suffix)
954 956 return view
955 957
956 958 def close(self, linger=None):
957 959 """Close my zmq Sockets
958 960
959 961 If `linger` is given, it is passed as the zmq LINGER socket option,
960 962 which bounds how long (in ms) unsent messages are kept before being discarded.
961 963 """
962 964 if self._closed:
963 965 return
964 966 self.stop_spin_thread()
965 967 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
966 968 for name in snames:
967 969 socket = getattr(self, name)
968 970 if socket is not None and not socket.closed:
969 971 if linger is not None:
970 972 socket.close(linger=linger)
971 973 else:
972 974 socket.close()
973 975 self._closed = True
974 976
975 977 def _spin_every(self, interval=1):
976 978 """target func for use in spin_thread"""
977 979 while True:
978 980 if self._stop_spinning.is_set():
979 981 return
980 982 time.sleep(interval)
981 983 self.spin()
982 984
983 985 def spin_thread(self, interval=1):
984 986 """call Client.spin() in a background thread on some regular interval
985 987
986 988 This helps ensure that messages don't pile up too much in the zmq queue
987 989 while you are working on other things, or just leaving an idle terminal.
988 990
989 991 It also helps limit potential padding of the `received` timestamp
990 992 on AsyncResult objects, used for timings.
991 993
992 994 Parameters
993 995 ----------
994 996
995 997 interval : float, optional
996 998 The interval on which to spin the client in the background thread
997 999 (simply passed to time.sleep).
998 1000
999 1001 Notes
1000 1002 -----
1001 1003
1002 1004 For precision timing, you may want to use this method to put a bound
1003 1005 on the jitter (in seconds) in `received` timestamps used
1004 1006 in AsyncResult.wall_time.
1005 1007
1006 1008 """
1007 1009 if self._spin_thread is not None:
1008 1010 self.stop_spin_thread()
1009 1011 self._stop_spinning.clear()
1010 1012 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1011 1013 self._spin_thread.daemon = True
1012 1014 self._spin_thread.start()
1013 1015
1014 1016 def stop_spin_thread(self):
1015 1017 """stop background spin_thread, if any"""
1016 1018 if self._spin_thread is not None:
1017 1019 self._stop_spinning.set()
1018 1020 self._spin_thread.join()
1019 1021 self._spin_thread = None
1020 1022
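A sketch of how `spin_thread` can bound the jitter on `received` timestamps during a timing run (illustrative; assumes a running cluster):

    import time
    from IPython.parallel import Client

    rc = Client()
    rc.spin_thread(interval=0.01)            # flush queues every 10 ms
    ar = rc[:].apply_async(time.sleep, 0.1)
    ar.get()
    print(ar.wall_time)                      # 'received' jitter now ~10 ms at most
    rc.stop_spin_thread()
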
1021 1023 def spin(self):
1022 1024 """Flush any registration notifications and execution results
1023 1025 waiting in the ZMQ queue.
1024 1026 """
1025 1027 if self._notification_socket:
1026 1028 self._flush_notifications()
1027 1029 if self._iopub_socket:
1028 1030 self._flush_iopub(self._iopub_socket)
1029 1031 if self._mux_socket:
1030 1032 self._flush_results(self._mux_socket)
1031 1033 if self._task_socket:
1032 1034 self._flush_results(self._task_socket)
1033 1035 if self._control_socket:
1034 1036 self._flush_control(self._control_socket)
1035 1037 if self._query_socket:
1036 1038 self._flush_ignored_hub_replies()
1037 1039
1038 1040 def wait(self, jobs=None, timeout=-1):
1039 1041 """waits on one or more `jobs`, for up to `timeout` seconds.
1040 1042
1041 1043 Parameters
1042 1044 ----------
1043 1045
1044 1046 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1045 1047 ints are indices to self.history
1046 1048 strs are msg_ids
1047 1049 default: wait on all outstanding messages
1048 1050 timeout : float
1049 1051 a time in seconds, after which to give up.
1050 1052 default is -1, which means no timeout
1051 1053
1052 1054 Returns
1053 1055 -------
1054 1056
1055 1057 True : when all msg_ids are done
1056 1058 False : timeout reached, some msg_ids still outstanding
1057 1059 """
1058 1060 tic = time.time()
1059 1061 if jobs is None:
1060 1062 theids = self.outstanding
1061 1063 else:
1062 1064 if isinstance(jobs, string_types + (int, AsyncResult)):
1063 1065 jobs = [jobs]
1064 1066 theids = set()
1065 1067 for job in jobs:
1066 1068 if isinstance(job, int):
1067 1069 # index access
1068 1070 job = self.history[job]
1069 1071 elif isinstance(job, AsyncResult):
1070 1072 theids.update(job.msg_ids)
1071 1073 continue
1072 1074 theids.add(job)
1073 1075 if not theids.intersection(self.outstanding):
1074 1076 return True
1075 1077 self.spin()
1076 1078 while theids.intersection(self.outstanding):
1077 1079 if timeout >= 0 and ( time.time()-tic ) > timeout:
1078 1080 break
1079 1081 time.sleep(1e-3)
1080 1082 self.spin()
1081 1083 return len(theids.intersection(self.outstanding)) == 0
1082 1084
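A hedged example of the `wait` semantics above (assumes a running cluster):

    import time
    from IPython.parallel import Client

    rc = Client()
    ar = rc[:].apply_async(time.sleep, 1)
    rc.wait([ar], timeout=0.1)   # returns False: tasks still outstanding
    rc.wait([ar])                # blocks, then returns True when all are done
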
1083 1085 #--------------------------------------------------------------------------
1084 1086 # Control methods
1085 1087 #--------------------------------------------------------------------------
1086 1088
1087 1089 @spin_first
1088 1090 def clear(self, targets=None, block=None):
1089 1091 """Clear the namespace in target(s)."""
1090 1092 block = self.block if block is None else block
1091 1093 targets = self._build_targets(targets)[0]
1092 1094 for t in targets:
1093 1095 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1094 1096 error = False
1095 1097 if block:
1096 1098 self._flush_ignored_control()
1097 1099 for i in range(len(targets)):
1098 1100 idents,msg = self.session.recv(self._control_socket,0)
1099 1101 if self.debug:
1100 1102 pprint(msg)
1101 1103 if msg['content']['status'] != 'ok':
1102 1104 error = self._unwrap_exception(msg['content'])
1103 1105 else:
1104 1106 self._ignored_control_replies += len(targets)
1105 1107 if error:
1106 1108 raise error
1107 1109
1108 1110
1109 1111 @spin_first
1110 1112 def abort(self, jobs=None, targets=None, block=None):
1111 1113 """Abort specific jobs from the execution queues of target(s).
1112 1114
1113 1115 This is a mechanism to prevent jobs that have already been submitted
1114 1116 from executing.
1115 1117
1116 1118 Parameters
1117 1119 ----------
1118 1120
1119 1121 jobs : msg_id, list of msg_ids, or AsyncResult
1120 1122 The jobs to be aborted
1121 1123
1122 1124 If unspecified/None: abort all outstanding jobs.
1123 1125
1124 1126 """
1125 1127 block = self.block if block is None else block
1126 1128 jobs = jobs if jobs is not None else list(self.outstanding)
1127 1129 targets = self._build_targets(targets)[0]
1128 1130
1129 1131 msg_ids = []
1130 1132 if isinstance(jobs, string_types + (AsyncResult,)):
1131 1133 jobs = [jobs]
1132 1134 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1133 1135 if bad_ids:
1134 1136 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1135 1137 for j in jobs:
1136 1138 if isinstance(j, AsyncResult):
1137 1139 msg_ids.extend(j.msg_ids)
1138 1140 else:
1139 1141 msg_ids.append(j)
1140 1142 content = dict(msg_ids=msg_ids)
1141 1143 for t in targets:
1142 1144 self.session.send(self._control_socket, 'abort_request',
1143 1145 content=content, ident=t)
1144 1146 error = False
1145 1147 if block:
1146 1148 self._flush_ignored_control()
1147 1149 for i in range(len(targets)):
1148 1150 idents,msg = self.session.recv(self._control_socket,0)
1149 1151 if self.debug:
1150 1152 pprint(msg)
1151 1153 if msg['content']['status'] != 'ok':
1152 1154 error = self._unwrap_exception(msg['content'])
1153 1155 else:
1154 1156 self._ignored_control_replies += len(targets)
1155 1157 if error:
1156 1158 raise error
1157 1159
1158 1160 @spin_first
1159 1161 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1160 1162 """Terminates one or more engine processes, optionally including the hub.
1161 1163
1162 1164 Parameters
1163 1165 ----------
1164 1166
1165 1167 targets: list of ints or 'all' [default: all]
1166 1168 Which engines to shutdown.
1167 1169 hub: bool [default: False]
1168 1170 Whether to include the Hub. hub=True implies targets='all'.
1169 1171 block: bool [default: self.block]
1170 1172 Whether to wait for clean shutdown replies or not.
1171 1173 restart: bool [default: False]
1172 1174 NOT IMPLEMENTED
1173 1175 whether to restart engines after shutting them down.
1174 1176 """
1175 1177 from IPython.parallel.error import NoEnginesRegistered
1176 1178 if restart:
1177 1179 raise NotImplementedError("Engine restart is not yet implemented")
1178 1180
1179 1181 block = self.block if block is None else block
1180 1182 if hub:
1181 1183 targets = 'all'
1182 1184 try:
1183 1185 targets = self._build_targets(targets)[0]
1184 1186 except NoEnginesRegistered:
1185 1187 targets = []
1186 1188 for t in targets:
1187 1189 self.session.send(self._control_socket, 'shutdown_request',
1188 1190 content={'restart':restart},ident=t)
1189 1191 error = False
1190 1192 if block or hub:
1191 1193 self._flush_ignored_control()
1192 1194 for i in range(len(targets)):
1193 1195 idents,msg = self.session.recv(self._control_socket, 0)
1194 1196 if self.debug:
1195 1197 pprint(msg)
1196 1198 if msg['content']['status'] != 'ok':
1197 1199 error = self._unwrap_exception(msg['content'])
1198 1200 else:
1199 1201 self._ignored_control_replies += len(targets)
1200 1202
1201 1203 if hub:
1202 1204 time.sleep(0.25)
1203 1205 self.session.send(self._query_socket, 'shutdown_request')
1204 1206 idents,msg = self.session.recv(self._query_socket, 0)
1205 1207 if self.debug:
1206 1208 pprint(msg)
1207 1209 if msg['content']['status'] != 'ok':
1208 1210 error = self._unwrap_exception(msg['content'])
1209 1211
1210 1212 if error:
1211 1213 raise error
1212 1214
1213 1215 #--------------------------------------------------------------------------
1214 1216 # Execution related methods
1215 1217 #--------------------------------------------------------------------------
1216 1218
1217 1219 def _maybe_raise(self, result):
1218 1220 """wrapper for maybe raising an exception if apply failed."""
1219 1221 if isinstance(result, error.RemoteError):
1220 1222 raise result
1221 1223
1222 1224 return result
1223 1225
1224 1226 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1225 1227 ident=None):
1226 1228 """construct and send an apply message via a socket.
1227 1229
1228 1230 This is the principal method with which all engine execution is performed by views.
1229 1231 """
1230 1232
1231 1233 if self._closed:
1232 1234 raise RuntimeError("Client cannot be used after its sockets have been closed")
1233 1235
1234 1236 # defaults:
1235 1237 args = args if args is not None else []
1236 1238 kwargs = kwargs if kwargs is not None else {}
1237 1239 metadata = metadata if metadata is not None else {}
1238 1240
1239 1241 # validate arguments
1240 1242 if not callable(f) and not isinstance(f, Reference):
1241 1243 raise TypeError("f must be callable, not %s"%type(f))
1242 1244 if not isinstance(args, (tuple, list)):
1243 1245 raise TypeError("args must be tuple or list, not %s"%type(args))
1244 1246 if not isinstance(kwargs, dict):
1245 1247 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1246 1248 if not isinstance(metadata, dict):
1247 1249 raise TypeError("metadata must be dict, not %s"%type(metadata))
1248 1250
1249 1251 bufs = serialize.pack_apply_message(f, args, kwargs,
1250 1252 buffer_threshold=self.session.buffer_threshold,
1251 1253 item_threshold=self.session.item_threshold,
1252 1254 )
1253 1255
1254 1256 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1255 1257 metadata=metadata, track=track)
1256 1258
1257 1259 msg_id = msg['header']['msg_id']
1258 1260 self.outstanding.add(msg_id)
1259 1261 if ident:
1260 1262 # possibly routed to a specific engine
1261 1263 if isinstance(ident, list):
1262 1264 ident = ident[-1]
1263 1265 if ident in self._engines.values():
1264 1266 # save for later, in case of engine death
1265 1267 self._outstanding_dict[ident].add(msg_id)
1266 1268 self.history.append(msg_id)
1267 1269 self.metadata[msg_id]['submitted'] = datetime.now()
1268 1270
1269 1271 return msg
1270 1272
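For illustration only, roughly what a `LoadBalancedView.apply_async` call reduces to; `_task_socket` is a private attribute and real views add routing, tracking, and metadata, so treat this strictly as a sketch:

    from IPython.parallel import Client

    rc = Client()
    msg = rc.send_apply_request(rc._task_socket, pow, args=(2, 10), kwargs={})
    msg_id = msg['header']['msg_id']
    rc.wait([msg_id])
    print(rc.results[msg_id])    # -> 1024
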
1271 1273 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1272 1274 """construct and send an execute request via a socket.
1273 1275
1274 1276 """
1275 1277
1276 1278 if self._closed:
1277 1279 raise RuntimeError("Client cannot be used after its sockets have been closed")
1278 1280
1279 1281 # defaults:
1280 1282 metadata = metadata if metadata is not None else {}
1281 1283
1282 1284 # validate arguments
1283 1285 if not isinstance(code, string_types):
1284 1286 raise TypeError("code must be text, not %s" % type(code))
1285 1287 if not isinstance(metadata, dict):
1286 1288 raise TypeError("metadata must be dict, not %s" % type(metadata))
1287 1289
1288 1290 content = dict(code=code, silent=bool(silent), user_expressions={})
1289 1291
1290 1292
1291 1293 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1292 1294 metadata=metadata)
1293 1295
1294 1296 msg_id = msg['header']['msg_id']
1295 1297 self.outstanding.add(msg_id)
1296 1298 if ident:
1297 1299 # possibly routed to a specific engine
1298 1300 if isinstance(ident, list):
1299 1301 ident = ident[-1]
1300 1302 if ident in self._engines.values():
1301 1303 # save for later, in case of engine death
1302 1304 self._outstanding_dict[ident].add(msg_id)
1303 1305 self.history.append(msg_id)
1304 1306 self.metadata[msg_id]['submitted'] = datetime.now()
1305 1307
1306 1308 return msg
1307 1309
1308 1310 #--------------------------------------------------------------------------
1309 1311 # construct a View object
1310 1312 #--------------------------------------------------------------------------
1311 1313
1312 1314 def load_balanced_view(self, targets=None):
1313 1315 """construct a DirectView object.
1314 1316
1315 1317 If no arguments are specified, create a LoadBalancedView
1316 1318 using all engines.
1317 1319
1318 1320 Parameters
1319 1321 ----------
1320 1322
1321 1323 targets: list, slice, int, etc. [default: use all engines]
1322 1324 The subset of engines across which to load-balance
1323 1325 """
1324 1326 if targets == 'all':
1325 1327 targets = None
1326 1328 if targets is not None:
1327 1329 targets = self._build_targets(targets)[1]
1328 1330 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1329 1331
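A short usage sketch (assumes a running cluster):

    from IPython.parallel import Client

    rc = Client()
    lview = rc.load_balanced_view()        # balance over all engines
    ar = lview.map_async(lambda x: x ** 2, range(32))
    print(ar.get())                        # [0, 1, 4, 9, ...]
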
1330 1332 def direct_view(self, targets='all'):
1331 1333 """construct a DirectView object.
1332 1334
1333 1335 If no targets are specified, create a DirectView using all engines.
1334 1336
1335 1337 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1336 1338 evaluate the target engines at each execution, whereas rc[:] will connect to
1337 1339 all *current* engines, and that list will not change.
1338 1340
1339 1341 That is, 'all' will always use all engines, whereas rc[:] will not use
1340 1342 engines added after the DirectView is constructed.
1341 1343
1342 1344 Parameters
1343 1345 ----------
1344 1346
1345 1347 targets: list, slice, int, etc. [default: use all engines]
1346 1348 The engines to use for the View
1347 1349 """
1348 1350 single = isinstance(targets, int)
1349 1351 # allow 'all' to be lazily evaluated at each execution
1350 1352 if targets != 'all':
1351 1353 targets = self._build_targets(targets)[1]
1352 1354 if single:
1353 1355 targets = targets[0]
1354 1356 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1355 1357
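An illustrative sketch of the `'all'` vs. `rc[:]` distinction described above (assumes a running cluster):

    from IPython.parallel import Client

    rc = Client()
    dv_all = rc.direct_view('all')   # re-resolves 'all' at every execution
    dv_now = rc[:]                   # pinned to the engines registered right now
    dv_all.execute('import os')      # later calls also reach engines added afterwards
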
1356 1358 #--------------------------------------------------------------------------
1357 1359 # Query methods
1358 1360 #--------------------------------------------------------------------------
1359 1361
1360 1362 @spin_first
1361 1363 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1362 1364 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1363 1365
1364 1366 If the client already has the results, no request to the Hub will be made.
1365 1367
1366 1368 This is a convenient way to construct AsyncResult objects, which are wrappers
1367 1369 that include metadata about execution, and allow for awaiting results that
1368 1370 were not submitted by this Client.
1369 1371
1370 1372 It can also be a convenient way to retrieve the metadata associated with
1371 1373 blocking execution, since it always retrieves the full metadata record.
1372 1374
1373 1375 Examples
1374 1376 --------
1375 1377 ::
1376 1378
1377 1379 In [10]: ar = client.get_result(-1)  # AsyncResult for the most recent task
1378 1380
1379 1381 Parameters
1380 1382 ----------
1381 1383
1382 1384 indices_or_msg_ids : integer history index, str msg_id, or list of either
1383 1385 The indices or msg_ids of the results to be retrieved
1384 1386
1385 1387 block : bool
1386 1388 Whether to wait for the result to be done
1387 1389 owner : bool [default: True]
1388 1390 Whether this AsyncResult should own the result.
1389 1391 If so, calling `ar.get()` will remove data from the
1390 1392 client's result and metadata cache.
1391 1393 There should only be one owner of any given msg_id.
1392 1394
1393 1395 Returns
1394 1396 -------
1395 1397
1396 1398 AsyncResult
1397 1399 A single AsyncResult object will always be returned.
1398 1400
1399 1401 AsyncHubResult
1400 1402 A subclass of AsyncResult that retrieves results from the Hub
1401 1403
1402 1404 """
1403 1405 block = self.block if block is None else block
1404 1406 if indices_or_msg_ids is None:
1405 1407 indices_or_msg_ids = -1
1406 1408
1407 1409 single_result = False
1408 1410 if not isinstance(indices_or_msg_ids, (list,tuple)):
1409 1411 indices_or_msg_ids = [indices_or_msg_ids]
1410 1412 single_result = True
1411 1413
1412 1414 theids = []
1413 1415 for id in indices_or_msg_ids:
1414 1416 if isinstance(id, int):
1415 1417 id = self.history[id]
1416 1418 if not isinstance(id, string_types):
1417 1419 raise TypeError("indices must be str or int, not %r"%id)
1418 1420 theids.append(id)
1419 1421
1420 1422 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1421 1423 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1422 1424
1423 1425 # given a single msg_id initially, get_result should return the result itself,
1424 1426 # not a length-one list
1425 1427 if single_result:
1426 1428 theids = theids[0]
1427 1429
1428 1430 if remote_ids:
1429 1431 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1430 1432 else:
1431 1433 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1432 1434
1433 1435 if block:
1434 1436 ar.wait()
1435 1437
1436 1438 return ar
1437 1439
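A brief sketch of retrieving a result by msg_id (assumes a running cluster; `owner=False` leaves the cached result for the original AsyncResult):

    from IPython.parallel import Client

    rc = Client()
    ar = rc[:].apply_async(sum, [1, 2, 3])
    msg_id = ar.msg_ids[0]
    ar2 = rc.get_result(msg_id, block=True, owner=False)
    print(ar2.get())
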
1438 1440 @spin_first
1439 1441 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1440 1442 """Resubmit one or more tasks.
1441 1443
1442 1444 in-flight tasks may not be resubmitted.
1443 1445
1444 1446 Parameters
1445 1447 ----------
1446 1448
1447 1449 indices_or_msg_ids : integer history index, str msg_id, or list of either
1448 1450 The indices or msg_ids of the tasks to be resubmitted
1449 1451
1450 1452 block : bool
1451 1453 Whether to wait for the result to be done
1452 1454
1453 1455 Returns
1454 1456 -------
1455 1457
1456 1458 AsyncHubResult
1457 1459 A subclass of AsyncResult that retrieves results from the Hub
1458 1460
1459 1461 """
1460 1462 block = self.block if block is None else block
1461 1463 if indices_or_msg_ids is None:
1462 1464 indices_or_msg_ids = -1
1463 1465
1464 1466 if not isinstance(indices_or_msg_ids, (list,tuple)):
1465 1467 indices_or_msg_ids = [indices_or_msg_ids]
1466 1468
1467 1469 theids = []
1468 1470 for id in indices_or_msg_ids:
1469 1471 if isinstance(id, int):
1470 1472 id = self.history[id]
1471 1473 if not isinstance(id, string_types):
1472 1474 raise TypeError("indices must be str or int, not %r"%id)
1473 1475 theids.append(id)
1474 1476
1475 1477 content = dict(msg_ids = theids)
1476 1478
1477 1479 self.session.send(self._query_socket, 'resubmit_request', content)
1478 1480
1479 1481 zmq.select([self._query_socket], [], [])
1480 1482 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1481 1483 if self.debug:
1482 1484 pprint(msg)
1483 1485 content = msg['content']
1484 1486 if content['status'] != 'ok':
1485 1487 raise self._unwrap_exception(content)
1486 1488 mapping = content['resubmitted']
1487 1489 new_ids = [ mapping[msg_id] for msg_id in theids ]
1488 1490
1489 1491 ar = AsyncHubResult(self, msg_ids=new_ids)
1490 1492
1491 1493 if block:
1492 1494 ar.wait()
1493 1495
1494 1496 return ar
1495 1497
1496 1498 @spin_first
1497 1499 def result_status(self, msg_ids, status_only=True):
1498 1500 """Check on the status of the result(s) of the apply request with `msg_ids`.
1499 1501
1500 1502 If status_only is False, then the actual results will be retrieved, else
1501 1503 only the status of the results will be checked.
1502 1504
1503 1505 Parameters
1504 1506 ----------
1505 1507
1506 1508 msg_ids : list of msg_ids
1507 1509 if int:
1508 1510 Passed as index to self.history for convenience.
1509 1511 status_only : bool (default: True)
1510 1512 if False:
1511 1513 Retrieve the actual results of completed tasks.
1512 1514
1513 1515 Returns
1514 1516 -------
1515 1517
1516 1518 results : dict
1517 1519 There will always be the keys 'pending' and 'completed', which will
1518 1520 be lists of msg_ids that are incomplete or complete. If `status_only`
1519 1521 is False, then completed results will be keyed by their `msg_id`.
1520 1522 """
1521 1523 if not isinstance(msg_ids, (list,tuple)):
1522 1524 msg_ids = [msg_ids]
1523 1525
1524 1526 theids = []
1525 1527 for msg_id in msg_ids:
1526 1528 if isinstance(msg_id, int):
1527 1529 msg_id = self.history[msg_id]
1528 1530 if not isinstance(msg_id, string_types):
1529 1531 raise TypeError("msg_ids must be str, not %r"%msg_id)
1530 1532 theids.append(msg_id)
1531 1533
1532 1534 completed = []
1533 1535 local_results = {}
1534 1536
1535 1537 # comment this block out to temporarily disable local shortcut:
1536 1538 for msg_id in list(theids):  # iterate over a copy: theids is mutated in the loop
1537 1539 if msg_id in self.results:
1538 1540 completed.append(msg_id)
1539 1541 local_results[msg_id] = self.results[msg_id]
1540 1542 theids.remove(msg_id)
1541 1543
1542 1544 if theids: # some not locally cached
1543 1545 content = dict(msg_ids=theids, status_only=status_only)
1544 1546 msg = self.session.send(self._query_socket, "result_request", content=content)
1545 1547 zmq.select([self._query_socket], [], [])
1546 1548 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1547 1549 if self.debug:
1548 1550 pprint(msg)
1549 1551 content = msg['content']
1550 1552 if content['status'] != 'ok':
1551 1553 raise self._unwrap_exception(content)
1552 1554 buffers = msg['buffers']
1553 1555 else:
1554 1556 content = dict(completed=[],pending=[])
1555 1557
1556 1558 content['completed'].extend(completed)
1557 1559
1558 1560 if status_only:
1559 1561 return content
1560 1562
1561 1563 failures = []
1562 1564 # load cached results into result:
1563 1565 content.update(local_results)
1564 1566
1565 1567 # update cache with results:
1566 1568 for msg_id in sorted(theids):
1567 1569 if msg_id in content['completed']:
1568 1570 rec = content[msg_id]
1569 1571 parent = extract_dates(rec['header'])
1570 1572 header = extract_dates(rec['result_header'])
1571 1573 rcontent = rec['result_content']
1572 1574 iodict = rec['io']
1573 1575 if isinstance(rcontent, str):
1574 1576 rcontent = self.session.unpack(rcontent)
1575 1577
1576 1578 md = self.metadata[msg_id]
1577 1579 md_msg = dict(
1578 1580 content=rcontent,
1579 1581 parent_header=parent,
1580 1582 header=header,
1581 1583 metadata=rec['result_metadata'],
1582 1584 )
1583 1585 md.update(self._extract_metadata(md_msg))
1584 1586 if rec.get('received'):
1585 1587 md['received'] = parse_date(rec['received'])
1586 1588 md.update(iodict)
1587 1589
1588 1590 if rcontent['status'] == 'ok':
1589 1591 if header['msg_type'] == 'apply_reply':
1590 1592 res,buffers = serialize.unserialize_object(buffers)
1591 1593 elif header['msg_type'] == 'execute_reply':
1592 1594 res = ExecuteReply(msg_id, rcontent, md)
1593 1595 else:
1594 1596 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1595 1597 else:
1596 1598 res = self._unwrap_exception(rcontent)
1597 1599 failures.append(res)
1598 1600
1599 1601 self.results[msg_id] = res
1600 1602 content[msg_id] = res
1601 1603
1602 1604 if len(theids) == 1 and failures:
1603 1605 raise failures[0]
1604 1606
1605 1607 error.collect_exceptions(failures, "result_status")
1606 1608 return content
1607 1609
1608 1610 @spin_first
1609 1611 def queue_status(self, targets='all', verbose=False):
1610 1612 """Fetch the status of engine queues.
1611 1613
1612 1614 Parameters
1613 1615 ----------
1614 1616
1615 1617 targets : int/str/list of ints/strs
1616 1618 the engines whose states are to be queried.
1617 1619 default : all
1618 1620 verbose : bool
1619 1621 Whether to return lengths only, or lists of ids for each element
1620 1622 """
1621 1623 if targets == 'all':
1622 1624 # allow 'all' to be evaluated on the engine
1623 1625 engine_ids = None
1624 1626 else:
1625 1627 engine_ids = self._build_targets(targets)[1]
1626 1628 content = dict(targets=engine_ids, verbose=verbose)
1627 1629 self.session.send(self._query_socket, "queue_request", content=content)
1628 1630 idents,msg = self.session.recv(self._query_socket, 0)
1629 1631 if self.debug:
1630 1632 pprint(msg)
1631 1633 content = msg['content']
1632 1634 status = content.pop('status')
1633 1635 if status != 'ok':
1634 1636 raise self._unwrap_exception(content)
1635 1637 content = rekey(content)
1636 1638 if isinstance(targets, int):
1637 1639 return content[targets]
1638 1640 else:
1639 1641 return content
1640 1642
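For example (illustrative; assumes a running cluster):

    from IPython.parallel import Client

    rc = Client()
    print(rc.queue_status())                         # per-engine queue/task counts
    print(rc.queue_status(targets=0, verbose=True))  # msg_id lists for engine 0
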
1641 1643 def _build_msgids_from_target(self, targets=None):
1642 1644 """Build a list of msg_ids from the list of engine targets"""
1643 1645 if not targets: # needed as _build_targets otherwise uses all engines
1644 1646 return []
1645 1647 target_ids = self._build_targets(targets)[0]
1646 1648 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1647 1649
1648 1650 def _build_msgids_from_jobs(self, jobs=None):
1649 1651 """Build a list of msg_ids from "jobs" """
1650 1652 if not jobs:
1651 1653 return []
1652 1654 msg_ids = []
1653 1655 if isinstance(jobs, string_types + (AsyncResult,)):
1654 1656 jobs = [jobs]
1655 1657 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1656 1658 if bad_ids:
1657 1659 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1658 1660 for j in jobs:
1659 1661 if isinstance(j, AsyncResult):
1660 1662 msg_ids.extend(j.msg_ids)
1661 1663 else:
1662 1664 msg_ids.append(j)
1663 1665 return msg_ids
1664 1666
1665 1667 def purge_local_results(self, jobs=[], targets=[]):
1666 1668 """Clears the client caches of results and their metadata.
1667 1669
1668 1670 Individual results can be purged by msg_id, or the entire
1669 1671 history of specific targets can be purged.
1670 1672
1671 1673 Use `purge_local_results('all')` to scrub everything from the Client's
1672 1674 results and metadata caches.
1673 1675
1674 1676 After this call all `AsyncResults` are invalid and should be discarded.
1675 1677
1676 1678 If you must "reget" the results, you can still do so by using
1677 1679 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 1680 redownload the results from the hub if they are still available
1679 1681 (i.e. `client.purge_hub_results(...)` has not been called).
1680 1682
1681 1683 Parameters
1682 1684 ----------
1683 1685
1684 1686 jobs : str or list of str or AsyncResult objects
1685 1687 the msg_ids whose results should be purged.
1686 1688 targets : int/list of ints
1687 1689 The engines, by integer ID, whose entire result histories are to be purged.
1688 1690
1689 1691 Raises
1690 1692 ------
1691 1693
1692 1694 RuntimeError : if any of the tasks to be purged are still outstanding.
1693 1695
1694 1696 """
1695 1697 if not targets and not jobs:
1696 1698 raise ValueError("Must specify at least one of `targets` and `jobs`")
1697 1699
1698 1700 if jobs == 'all':
1699 1701 if self.outstanding:
1700 1702 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1701 1703 self.results.clear()
1702 1704 self.metadata.clear()
1703 1705 else:
1704 1706 msg_ids = set()
1705 1707 msg_ids.update(self._build_msgids_from_target(targets))
1706 1708 msg_ids.update(self._build_msgids_from_jobs(jobs))
1707 1709 still_outstanding = self.outstanding.intersection(msg_ids)
1708 1710 if still_outstanding:
1709 1711 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1710 1712 for mid in msg_ids:
1711 1713 self.results.pop(mid, None)
1712 1714 self.metadata.pop(mid, None)
1713 1715
1714 1716
1715 1717 @spin_first
1716 1718 def purge_hub_results(self, jobs=[], targets=[]):
1717 1719 """Tell the Hub to forget results.
1718 1720
1719 1721 Individual results can be purged by msg_id, or the entire
1720 1722 history of specific targets can be purged.
1721 1723
1722 1724 Use `purge_results('all')` to scrub everything from the Hub's db.
1723 1725
1724 1726 Parameters
1725 1727 ----------
1726 1728
1727 1729 jobs : str or list of str or AsyncResult objects
1728 1730 the msg_ids whose results should be forgotten.
1729 1731 targets : int/str/list of ints/strs
1730 1732 The targets, by int_id, whose entire history is to be purged.
1731 1733
1732 1734 default : None
1733 1735 """
1734 1736 if not targets and not jobs:
1735 1737 raise ValueError("Must specify at least one of `targets` and `jobs`")
1736 1738 if targets:
1737 1739 targets = self._build_targets(targets)[1]
1738 1740
1739 1741 # construct msg_ids from jobs
1740 1742 if jobs == 'all':
1741 1743 msg_ids = jobs
1742 1744 else:
1743 1745 msg_ids = self._build_msgids_from_jobs(jobs)
1744 1746
1745 1747 content = dict(engine_ids=targets, msg_ids=msg_ids)
1746 1748 self.session.send(self._query_socket, "purge_request", content=content)
1747 1749 idents, msg = self.session.recv(self._query_socket, 0)
1748 1750 if self.debug:
1749 1751 pprint(msg)
1750 1752 content = msg['content']
1751 1753 if content['status'] != 'ok':
1752 1754 raise self._unwrap_exception(content)
1753 1755
1754 1756 def purge_results(self, jobs=[], targets=[]):
1755 1757 """Clears the cached results from both the hub and the local client
1756 1758
1757 1759 Individual results can be purged by msg_id, or the entire
1758 1760 history of specific targets can be purged.
1759 1761
1760 1762 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1761 1763 the Client's db.
1762 1764
1763 1765 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1764 1766 the same arguments.
1765 1767
1766 1768 Parameters
1767 1769 ----------
1768 1770
1769 1771 jobs : str or list of str or AsyncResult objects
1770 1772 the msg_ids whose results should be forgotten.
1771 1773 targets : int/str/list of ints/strs
1772 1774 The targets, by int_id, whose entire history is to be purged.
1773 1775
1774 1776 default : None
1775 1777 """
1776 1778 self.purge_local_results(jobs=jobs, targets=targets)
1777 1779 self.purge_hub_results(jobs=jobs, targets=targets)
1778 1780
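A small usage sketch (assumes a running cluster):

    from IPython.parallel import Client

    rc = Client()
    rc.wait()                          # purging requires no outstanding tasks
    rc.purge_results(jobs='all')       # forget every cached result, hub and client
    rc.purge_results(targets=[0, 1])   # or: only what engines 0 and 1 ran
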
1779 1781 def purge_everything(self):
1780 1782 """Clears all content from previous Tasks from both the hub and the local client
1781 1783
1782 1784 In addition to calling `purge_results("all")` it also deletes the history and
1783 1785 other bookkeeping lists.
1784 1786 """
1785 1787 self.purge_results("all")
1786 1788 self.history = []
1787 1789 self.session.digest_history.clear()
1788 1790
1789 1791 @spin_first
1790 1792 def hub_history(self):
1791 1793 """Get the Hub's history
1792 1794
1793 1795 Just like the Client, the Hub has a history, which is a list of msg_ids.
1794 1796 This will contain the history of all clients, and, depending on configuration,
1795 1797 may contain history across multiple cluster sessions.
1796 1798
1797 1799 Any msg_id returned here is a valid argument to `get_result`.
1798 1800
1799 1801 Returns
1800 1802 -------
1801 1803
1802 1804 msg_ids : list of strs
1803 1805 list of all msg_ids, ordered by task submission time.
1804 1806 """
1805 1807
1806 1808 self.session.send(self._query_socket, "history_request", content={})
1807 1809 idents, msg = self.session.recv(self._query_socket, 0)
1808 1810
1809 1811 if self.debug:
1810 1812 pprint(msg)
1811 1813 content = msg['content']
1812 1814 if content['status'] != 'ok':
1813 1815 raise self._unwrap_exception(content)
1814 1816 else:
1815 1817 return content['history']
1816 1818
1817 1819 @spin_first
1818 1820 def db_query(self, query, keys=None):
1819 1821 """Query the Hub's TaskRecord database
1820 1822
1821 1823 This will return a list of task record dicts that match `query`
1822 1824
1823 1825 Parameters
1824 1826 ----------
1825 1827
1826 1828 query : mongodb query dict
1827 1829 The search dict. See mongodb query docs for details.
1828 1830 keys : list of strs [optional]
1829 1831 The subset of keys to be returned. The default is to fetch everything but buffers.
1830 1832 'msg_id' will *always* be included.
1831 1833 """
1832 1834 if isinstance(keys, string_types):
1833 1835 keys = [keys]
1834 1836 content = dict(query=query, keys=keys)
1835 1837 self.session.send(self._query_socket, "db_request", content=content)
1836 1838 idents, msg = self.session.recv(self._query_socket, 0)
1837 1839 if self.debug:
1838 1840 pprint(msg)
1839 1841 content = msg['content']
1840 1842 if content['status'] != 'ok':
1841 1843 raise self._unwrap_exception(content)
1842 1844
1843 1845 records = content['records']
1844 1846
1845 1847 buffer_lens = content['buffer_lens']
1846 1848 result_buffer_lens = content['result_buffer_lens']
1847 1849 buffers = msg['buffers']
1848 1850 has_bufs = buffer_lens is not None
1849 1851 has_rbufs = result_buffer_lens is not None
1850 1852 for i,rec in enumerate(records):
1851 1853 # unpack datetime objects
1852 1854 for hkey in ('header', 'result_header'):
1853 1855 if hkey in rec:
1854 1856 rec[hkey] = extract_dates(rec[hkey])
1855 1857 for dtkey in ('submitted', 'started', 'completed', 'received'):
1856 1858 if dtkey in rec:
1857 1859 rec[dtkey] = parse_date(rec[dtkey])
1858 1860 # relink buffers
1859 1861 if has_bufs:
1860 1862 blen = buffer_lens[i]
1861 1863 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1862 1864 if has_rbufs:
1863 1865 blen = result_buffer_lens[i]
1864 1866 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1865 1867
1866 1868 return records
1867 1869
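A hedged example query (assumes a running cluster; the `$gte` operator is mongodb-style, as the docstring notes):

    from datetime import datetime, timedelta
    from IPython.parallel import Client

    rc = Client()
    since = datetime.now() - timedelta(hours=1)
    recs = rc.db_query({'submitted': {'$gte': since}},
                       keys=['msg_id', 'completed', 'engine_uuid'])
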
1868 1870 __all__ = [ 'Client' ]
@@ -1,166 +1,129 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
14 7 """
15 8
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
22
23 #-------------------------------------------------------------------------------
24 # Imports
25 #-------------------------------------------------------------------------------
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
26 11
27 12 from __future__ import division
28 13
14 import sys
29 15 from itertools import islice
30 16
31 17 from IPython.utils.data import flatten as utils_flatten
32 18
33 #-------------------------------------------------------------------------------
34 # Figure out which array packages are present and their array types
35 #-------------------------------------------------------------------------------
36 19
37 arrayModules = []
38 try:
39 import Numeric
40 except ImportError:
41 pass
42 else:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 try:
20 numpy = None
21
22 def is_array(obj):
23 """Is an object a numpy array?
24
25 Avoids importing numpy until it is requested
26 """
27 global numpy
28 if 'numpy' not in sys.modules:
29 return False
30
31 if numpy is None:
45 32 import numpy
46 except ImportError:
47 pass
48 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 import numarray
52 except ImportError:
53 pass
54 else:
55 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
33 return isinstance(obj, numpy.ndarray)
57 34
58 35 class Map(object):
59 36 """A class for partitioning a sequence using a map."""
60 37
61 38 def getPartition(self, seq, p, q, n=None):
62 39 """Returns the pth partition of q partitions of seq.
63 40
64 41 The length can be specified as `n`,
65 42 otherwise it is the value of `len(seq)`
66 43 """
67 44 n = len(seq) if n is None else n
68 45 # Test for error conditions here
69 46 if p<0 or p>=q:
70 47 raise ValueError("must have 0 <= p < q, but have p=%s, q=%s" % (p, q))
71 48
72 49 remainder = n % q
73 50 basesize = n // q
74 51
75 52 if p < remainder:
76 53 low = p * (basesize + 1)
77 54 high = low + basesize + 1
78 55 else:
79 56 low = p * basesize + remainder
80 57 high = low + basesize
81 58
82 59 try:
83 60 result = seq[low:high]
84 61 except TypeError:
85 62 # some objects (iterators) can't be sliced,
86 63 # use islice:
87 64 result = list(islice(seq, low, high))
88 65
89 66 return result
90 67
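A quick worked check of the partition arithmetic above (the first `remainder` partitions get `basesize + 1` items each); the import path is assumed from this module's location:

    from IPython.parallel.client.map import Map

    m = Map()
    seq = list(range(10))
    # n=10, q=3 -> basesize=3, remainder=1, so partition sizes are 4, 3, 3:
    parts = [m.getPartition(seq, p, 3) for p in range(3)]
    assert parts == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
    assert m.joinPartitions(parts) == seq
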
91 68 def joinPartitions(self, listOfPartitions):
92 69 return self.concatenate(listOfPartitions)
93 70
94 71 def concatenate(self, listOfPartitions):
95 72 testObject = listOfPartitions[0]
96 73 # First see if we have a known array type
97 for m in arrayModules:
98 #print m
99 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
74 if is_array(testObject):
75 return numpy.concatenate(listOfPartitions)
101 76 # Next try for Python sequence types
102 77 if isinstance(testObject, (list, tuple)):
103 78 return utils_flatten(listOfPartitions)
104 79 # If we have scalars, just return listOfPartitions
105 80 return listOfPartitions
106 81
107 82 class RoundRobinMap(Map):
108 83 """Partitions a sequence in a round robin fashion.
109 84
110 85 This currently does not work!
111 86 """
112 87
113 88 def getPartition(self, seq, p, q, n=None):
114 89 n = len(seq) if n is None else n
115 90 return seq[p:n:q]
116 91
117 92 def joinPartitions(self, listOfPartitions):
118 93 testObject = listOfPartitions[0]
119 94 # First see if we have a known array type
120 for m in arrayModules:
121 #print m
122 if isinstance(testObject, m['type']):
123 return self.flatten_array(m['type'], listOfPartitions)
95 if is_array(testObject):
96 return self.flatten_array(listOfPartitions)
124 97 if isinstance(testObject, (list, tuple)):
125 98 return self.flatten_list(listOfPartitions)
126 99 return listOfPartitions
127 100
128 def flatten_array(self, klass, listOfPartitions):
101 def flatten_array(self, listOfPartitions):
129 102 test = listOfPartitions[0]
130 103 shape = list(test.shape)
131 104 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 A = klass(shape)
105 A = numpy.ndarray(shape)
133 106 N = shape[0]
134 107 q = len(listOfPartitions)
135 108 for p,part in enumerate(listOfPartitions):
136 109 A[p:N:q] = part
137 110 return A
138 111
139 112 def flatten_list(self, listOfPartitions):
140 113 flat = []
141 114 for i in range(len(listOfPartitions[0])):
142 115 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 116 return flat
144 #lengths = [len(x) for x in listOfPartitions]
145 #maxPartitionLength = len(listOfPartitions[0])
146 #numberOfPartitions = len(listOfPartitions)
147 #concat = self.concatenate(listOfPartitions)
148 #totalLength = len(concat)
149 #result = []
150 #for i in range(maxPartitionLength):
151 # result.append(concat[i:totalLength:maxPartitionLength])
152 # return self.concatenate(listOfPartitions)
153 117
154 118 def mappable(obj):
155 119 """return whether an object is mappable or not."""
156 120 if isinstance(obj, (tuple,list)):
157 121 return True
158 for m in arrayModules:
159 if isinstance(obj,m['type']):
122 if is_array(obj):
160 123 return True
161 124 return False
162 125
163 126 dists = {'b':Map,'r':RoundRobinMap}
164 127
165 128
166 129
@@ -1,297 +1,301 b''
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 It handles registration, etc., and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from __future__ import print_function
10 10
11 11 import sys
12 12 import time
13 13 from getpass import getpass
14 14
15 15 import zmq
16 16 from zmq.eventloop import ioloop, zmqstream
17 from zmq.ssh import tunnel
18 17
19 18 from IPython.utils.localinterfaces import localhost
20 19 from IPython.utils.traitlets import (
21 Instance, Dict, Integer, Type, Float, Integer, Unicode, CBytes, Bool
20 Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
22 21 )
23 22 from IPython.utils.py3compat import cast_bytes
24 23
25 24 from IPython.parallel.controller.heartmonitor import Heart
26 25 from IPython.parallel.factory import RegistrationFactory
27 26 from IPython.parallel.util import disambiguate_url
28 27
29 28 from IPython.kernel.zmq.session import Message
30 29 from IPython.kernel.zmq.ipkernel import Kernel
31 30 from IPython.kernel.zmq.kernelapp import IPKernelApp
32 31
33 32 class EngineFactory(RegistrationFactory):
34 33 """IPython engine"""
35 34
36 35 # configurables:
37 36 out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
38 37 help="""The OutStream for handling stdout/err.
39 38 Typically 'IPython.kernel.zmq.iostream.OutStream'""")
40 39 display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
41 40 help="""The class for handling displayhook.
42 41 Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
43 42 location=Unicode(config=True,
44 43 help="""The location (an IP address) of the controller. This is
45 44 used for disambiguating URLs, to determine whether
46 45 loopback should be used to connect or the public address.""")
47 46 timeout=Float(5.0, config=True,
48 47 help="""The time (in seconds) to wait for the Controller to respond
49 48 to registration requests before giving up.""")
50 49 max_heartbeat_misses=Integer(50, config=True,
51 50 help="""The maximum number of times a check for the heartbeat ping of a
52 51 controller can be missed before shutting down the engine.
53 52
54 53 If set to 0, the check is disabled.""")
55 54 sshserver=Unicode(config=True,
56 55 help="""The SSH server to use for tunneling connections to the Controller.""")
57 56 sshkey=Unicode(config=True,
58 57 help="""The SSH private key file to use when tunneling connections to the Controller.""")
59 58 paramiko=Bool(sys.platform == 'win32', config=True,
60 59 help="""Whether to use paramiko instead of openssh for tunnels.""")
61 60
61 @property
62 def tunnel_mod(self):
63 from zmq.ssh import tunnel
64 return tunnel
65
62 66
63 67 # not configurable:
64 68 connection_info = Dict()
65 69 user_ns = Dict()
66 70 id = Integer(allow_none=True)
67 71 registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
68 72 kernel = Instance(Kernel)
69 73 hb_check_period=Integer()
70 74
71 75 # States for the heartbeat monitoring
72 76 # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
73 77 # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
74 78 _hb_last_pinged = 0.0
75 79 _hb_last_monitored = 0.0
76 80 _hb_missed_beats = 0
77 81 # The zmq Stream which receives the pings from the Heart
78 82 _hb_listener = None
79 83
80 84 bident = CBytes()
81 85 ident = Unicode()
82 86 def _ident_changed(self, name, old, new):
83 87 self.bident = cast_bytes(new)
84 88 using_ssh=Bool(False)
85 89
86 90
87 91 def __init__(self, **kwargs):
88 92 super(EngineFactory, self).__init__(**kwargs)
89 93 self.ident = self.session.session
90 94
91 95 def init_connector(self):
92 96 """construct connection function, which handles tunnels."""
93 97 self.using_ssh = bool(self.sshkey or self.sshserver)
94 98
95 99 if self.sshkey and not self.sshserver:
96 100 # We are using ssh directly to the controller, tunneling localhost to localhost
97 101 self.sshserver = self.url.split('://')[1].split(':')[0]
98 102
99 103 if self.using_ssh:
100 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
104 if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
101 105 password=False
102 106 else:
103 107 password = getpass("SSH Password for %s: "%self.sshserver)
104 108 else:
105 109 password = False
106 110
107 111 def connect(s, url):
108 112 url = disambiguate_url(url, self.location)
109 113 if self.using_ssh:
110 114 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
111 return tunnel.tunnel_connection(s, url, self.sshserver,
115 return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
112 116 keyfile=self.sshkey, paramiko=self.paramiko,
113 117 password=password,
114 118 )
115 119 else:
116 120 return s.connect(url)
117 121
118 122 def maybe_tunnel(url):
119 123 """like connect, but don't complete the connection (for use by heartbeat)"""
120 124 url = disambiguate_url(url, self.location)
121 125 if self.using_ssh:
122 126 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
123 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
127 url, tunnelobj = self.tunnel_mod.open_tunnel(url, self.sshserver,
124 128 keyfile=self.sshkey, paramiko=self.paramiko,
125 129 password=password,
126 130 )
127 131 return str(url)
128 132 return connect, maybe_tunnel
129 133
130 134 def register(self):
131 135 """send the registration_request"""
132 136
133 137 self.log.info("Registering with controller at %s"%self.url)
134 138 ctx = self.context
135 139 connect,maybe_tunnel = self.init_connector()
136 140 reg = ctx.socket(zmq.DEALER)
137 141 reg.setsockopt(zmq.IDENTITY, self.bident)
138 142 connect(reg, self.url)
139 143 self.registrar = zmqstream.ZMQStream(reg, self.loop)
140 144
141 145
142 146 content = dict(uuid=self.ident)
143 147 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
144 148 # print (self.session.key)
145 149 self.session.send(self.registrar, "registration_request", content=content)
146 150
147 151 def _report_ping(self, msg):
148 152 """Callback for when the heartmonitor.Heart receives a ping"""
149 153 #self.log.debug("Received a ping: %s", msg)
150 154 self._hb_last_pinged = time.time()
151 155
152 156 def complete_registration(self, msg, connect, maybe_tunnel):
153 157 # print msg
154 158 self._abort_dc.stop()
155 159 ctx = self.context
156 160 loop = self.loop
157 161 identity = self.bident
158 162 idents,msg = self.session.feed_identities(msg)
159 163 msg = self.session.unserialize(msg)
160 164 content = msg['content']
161 165 info = self.connection_info
162 166
163 167 def url(key):
164 168 """get zmq url for given channel"""
165 169 return str(info["interface"] + ":%i" % info[key])
166 170
167 171 if content['status'] == 'ok':
168 172 self.id = int(content['id'])
169 173
170 174 # launch heartbeat
171 175 # possibly forward hb ports with tunnels
172 176 hb_ping = maybe_tunnel(url('hb_ping'))
173 177 hb_pong = maybe_tunnel(url('hb_pong'))
174 178
175 179 hb_monitor = None
176 180 if self.max_heartbeat_misses > 0:
177 181 # Add a monitor socket which will record the last time a ping was seen
178 182 mon = self.context.socket(zmq.SUB)
179 183 mport = mon.bind_to_random_port('tcp://%s' % localhost())
180 184 mon.setsockopt(zmq.SUBSCRIBE, b"")
181 185 self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
182 186 self._hb_listener.on_recv(self._report_ping)
183 187
184 188
185 189 hb_monitor = "tcp://%s:%i" % (localhost(), mport)
186 190
187 191 heart = Heart(hb_ping, hb_pong, hb_monitor , heart_id=identity)
188 192 heart.start()
189 193
190 194 # create Shell Connections (MUX, Task, etc.):
191 195 shell_addrs = url('mux'), url('task')
192 196
193 197 # Use only one shell stream for mux and tasks
194 198 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
195 199 stream.setsockopt(zmq.IDENTITY, identity)
196 200 shell_streams = [stream]
197 201 for addr in shell_addrs:
198 202 connect(stream, addr)
199 203
200 204 # control stream:
201 205 control_addr = url('control')
202 206 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
203 207 control_stream.setsockopt(zmq.IDENTITY, identity)
204 208 connect(control_stream, control_addr)
205 209
206 210 # create iopub stream:
207 211 iopub_addr = url('iopub')
208 212 iopub_socket = ctx.socket(zmq.PUB)
209 213 iopub_socket.setsockopt(zmq.IDENTITY, identity)
210 214 connect(iopub_socket, iopub_addr)
211 215
212 216 # disable history:
213 217 self.config.HistoryManager.hist_file = ':memory:'
214 218
215 219 # Redirect input streams and set a display hook.
216 220 if self.out_stream_factory:
217 221 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
218 222 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
219 223 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
220 224 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
221 225 if self.display_hook_factory:
222 226 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
223 227 sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)
224 228
225 229 self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
226 230 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
227 231 loop=loop, user_ns=self.user_ns, log=self.log)
228 232
229 233 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
230 234
231 235
232 236 # periodically check the heartbeat pings of the controller
233 237 # Should be started here and not in "start()" so that the right period can be taken
234 238 # from the hubs HeartBeatMonitor.period
235 239 if self.max_heartbeat_misses > 0:
236 240 # Use a slightly bigger check period than the hub signal period, to avoid unnecessary warnings
237 241 self.hb_check_period = int(content['hb_period'])+10
238 242 self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms." , self.hb_check_period)
239 243 self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
240 244 self._hb_reporter.start()
241 245 else:
242 246 self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")
243 247
244 248
245 249 # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
246 250 app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
247 251 app.init_profile_dir()
248 252 app.init_code()
249 253
250 254 self.kernel.start()
251 255 else:
252 256 self.log.fatal("Registration Failed: %s"%msg)
253 257 raise Exception("Registration Failed: %s"%msg)
254 258
255 259 self.log.info("Completed registration with id %i"%self.id)
256 260
257 261
258 262 def abort(self):
259 263 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
260 264 if self.url.startswith('127.'):
261 265 self.log.fatal("""
262 266 If the controller and engines are not on the same machine,
263 267 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
264 268 c.HubFactory.ip='*' # for all interfaces, internal and external
265 269 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
266 270 or tunnel connections via ssh.
267 271 """)
268 272 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
269 273 time.sleep(1)
270 274 sys.exit(255)
271 275
272 276 def _hb_monitor(self):
273 277 """Callback to monitor the heartbeat from the controller"""
274 278 self._hb_listener.flush()
275 279 if self._hb_last_monitored > self._hb_last_pinged:
276 280 self._hb_missed_beats += 1
277 281 self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
278 282 else:
279 283 #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
280 284 self._hb_missed_beats = 0
281 285
282 286 if self._hb_missed_beats >= self.max_heartbeat_misses:
283 287 self.log.fatal("Maximum number of heartbeat misses reached (%s times %s ms), shutting down.",
284 288 self.max_heartbeat_misses, self.hb_check_period)
285 289 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
286 290 self.loop.stop()
287 291
288 292 self._hb_last_monitored = time.time()
289 293
290 294
291 295 def start(self):
292 296 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
293 297 dc.start()
294 298 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
295 299 self._abort_dc.start()
296 300
297 301