Allow shutdown when no engines are registered.
David Hirschfeld
@@ -1,1726 +1,1729 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52 from IPython.zmq import serialize
53 53
54 54 from .asyncresult import AsyncResult, AsyncHubResult
55 55 from .view import DirectView, LoadBalancedView
56 56
57 57 if sys.version_info[0] >= 3:
58 58 # xrange is used in a couple 'isinstance' tests in py2
59 59 # should be just 'range' in 3k
60 60 xrange = range
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Decorators for Client methods
64 64 #--------------------------------------------------------------------------
65 65
66 66 @decorator
67 67 def spin_first(f, self, *args, **kwargs):
68 68 """Call spin() to sync state prior to calling the method."""
69 69 self.spin()
70 70 return f(self, *args, **kwargs)
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Classes
75 75 #--------------------------------------------------------------------------
76 76
77 77
78 78 class ExecuteReply(object):
79 79 """wrapper for finished Execute results"""
80 80 def __init__(self, msg_id, content, metadata):
81 81 self.msg_id = msg_id
82 82 self._content = content
83 83 self.execution_count = content['execution_count']
84 84 self.metadata = metadata
85 85
86 86 def __getitem__(self, key):
87 87 return self.metadata[key]
88 88
89 89 def __getattr__(self, key):
90 90 if key not in self.metadata:
91 91 raise AttributeError(key)
92 92 return self.metadata[key]
93 93
94 94 def __repr__(self):
95 95 pyout = self.metadata['pyout'] or {'data':{}}
96 96 text_out = pyout['data'].get('text/plain', '')
97 97 if len(text_out) > 32:
98 98 text_out = text_out[:29] + '...'
99 99
100 100 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101 101
102 102 def _repr_pretty_(self, p, cycle):
103 103 pyout = self.metadata['pyout'] or {'data':{}}
104 104 text_out = pyout['data'].get('text/plain', '')
105 105
106 106 if not text_out:
107 107 return
108 108
109 109 try:
110 110 ip = get_ipython()
111 111 except NameError:
112 112 colors = "NoColor"
113 113 else:
114 114 colors = ip.colors
115 115
116 116 if colors == "NoColor":
117 117 out = normal = ""
118 118 else:
119 119 out = TermColors.Red
120 120 normal = TermColors.Normal
121 121
122 122 if '\n' in text_out and not text_out.startswith('\n'):
123 123 # add newline for multiline reprs
124 124 text_out = '\n' + text_out
125 125
126 126 p.text(
127 127 out + u'Out[%i:%i]: ' % (
128 128 self.metadata['engine_id'], self.execution_count
129 129 ) + normal + text_out
130 130 )
131 131
132 132 def _repr_html_(self):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 return pyout['data'].get("text/html")
135 135
136 136 def _repr_latex_(self):
137 137 pyout = self.metadata['pyout'] or {'data':{}}
138 138 return pyout['data'].get("text/latex")
139 139
140 140 def _repr_json_(self):
141 141 pyout = self.metadata['pyout'] or {'data':{}}
142 142 return pyout['data'].get("application/json")
143 143
144 144 def _repr_javascript_(self):
145 145 pyout = self.metadata['pyout'] or {'data':{}}
146 146 return pyout['data'].get("application/javascript")
147 147
148 148 def _repr_png_(self):
149 149 pyout = self.metadata['pyout'] or {'data':{}}
150 150 return pyout['data'].get("image/png")
151 151
152 152 def _repr_jpeg_(self):
153 153 pyout = self.metadata['pyout'] or {'data':{}}
154 154 return pyout['data'].get("image/jpeg")
155 155
156 156 def _repr_svg_(self):
157 157 pyout = self.metadata['pyout'] or {'data':{}}
158 158 return pyout['data'].get("image/svg+xml")
159 159
160 160
161 161 class Metadata(dict):
162 162 """Subclass of dict for initializing metadata values.
163 163
164 164 Attribute access works on keys.
165 165
166 166 These objects have a strict set of keys - errors will raise if you try
167 167 to add new keys.
168 168 """
169 169 def __init__(self, *args, **kwargs):
170 170 dict.__init__(self)
171 171 md = {'msg_id' : None,
172 172 'submitted' : None,
173 173 'started' : None,
174 174 'completed' : None,
175 175 'received' : None,
176 176 'engine_uuid' : None,
177 177 'engine_id' : None,
178 178 'follow' : None,
179 179 'after' : None,
180 180 'status' : None,
181 181
182 182 'pyin' : None,
183 183 'pyout' : None,
184 184 'pyerr' : None,
185 185 'stdout' : '',
186 186 'stderr' : '',
187 187 'outputs' : [],
188 188 'data': {},
189 189 'outputs_ready' : False,
190 190 }
191 191 self.update(md)
192 192 self.update(dict(*args, **kwargs))
193 193
194 194 def __getattr__(self, key):
195 195 """getattr aliased to getitem"""
196 196 if key in self.iterkeys():
197 197 return self[key]
198 198 else:
199 199 raise AttributeError(key)
200 200
201 201 def __setattr__(self, key, value):
202 202 """setattr aliased to setitem, with strict"""
203 203 if key in self.iterkeys():
204 204 self[key] = value
205 205 else:
206 206 raise AttributeError(key)
207 207
208 208 def __setitem__(self, key, value):
209 209 """strict static key enforcement"""
210 210 if key in self.iterkeys():
211 211 dict.__setitem__(self, key, value)
212 212 else:
213 213 raise KeyError(key)
214 214
215 215
216 216 class Client(HasTraits):
217 217 """A semi-synchronous client to the IPython ZMQ cluster
218 218
219 219 Parameters
220 220 ----------
221 221
222 222 url_file : str/unicode; path to ipcontroller-client.json
223 223 This JSON file should contain all the information needed to connect to a cluster,
224 224 and is likely the only argument needed.
225 225 Connection information for the Hub's registration. If a json connector
226 226 file is given, then likely no further configuration is necessary.
227 227 [Default: use profile]
228 228 profile : bytes
229 229 The name of the Cluster profile to be used to find connector information.
230 230 If run from an IPython application, the default profile will be the same
231 231 as the running application, otherwise it will be 'default'.
232 232 cluster_id : str
233 233 String id to add to runtime files, to prevent name collisions when using
234 234 multiple clusters with a single profile simultaneously.
235 235 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 236 Since this is text inserted into filenames, typical recommendations apply:
237 237 Simple character strings are ideal, and spaces are not recommended (but
238 238 should generally work)
239 239 context : zmq.Context
240 240 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 241 debug : bool
242 242 flag for lots of message printing for debug purposes
243 243 timeout : int/float
244 244 time (in seconds) to wait for connection replies from the Hub
245 245 [Default: 10]
246 246
247 247 #-------------- session related args ----------------
248 248
249 249 config : Config object
250 250 If specified, this will be relayed to the Session for configuration
251 251 username : str
252 252 set username for the session object
253 253
254 254 #-------------- ssh related args ----------------
255 255 # These are args for configuring the ssh tunnel to be used
256 256 # credentials are used to forward connections over ssh to the Controller
257 257 # Note that the ip given in `addr` needs to be relative to sshserver
258 258 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 259 # and set sshserver as the same machine the Controller is on. However,
260 260 # the only requirement is that sshserver is able to see the Controller
261 261 # (i.e. is within the same trusted network).
262 262
263 263 sshserver : str
264 264 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 265 If keyfile or password is specified, and this is not, it will default to
266 266 the ip given in addr.
267 267 sshkey : str; path to ssh private key file
268 268 This specifies a key to be used in ssh login, default None.
269 269 Regular default ssh keys will be used without specifying this argument.
270 270 password : str
271 271 Your ssh password to sshserver. Note that if this is left None,
272 272 you will be prompted for it if passwordless key based login is unavailable.
273 273 paramiko : bool
274 274 flag for whether to use paramiko instead of shell ssh for tunneling.
275 275 [default: True on win32, False otherwise]
276 276
277 277
278 278 Attributes
279 279 ----------
280 280
281 281 ids : list of int engine IDs
282 282 requesting the ids attribute always synchronizes
283 283 the registration state. To request ids without synchronization,
284 284 use the semi-private _ids attribute.
285 285
286 286 history : list of msg_ids
287 287 a list of msg_ids, keeping track of all the execution
288 288 messages you have submitted in order.
289 289
290 290 outstanding : set of msg_ids
291 291 a set of msg_ids that have been submitted, but whose
292 292 results have not yet been received.
293 293
294 294 results : dict
295 295 a dict of all our results, keyed by msg_id
296 296
297 297 block : bool
298 298 determines default behavior when block not specified
299 299 in execution methods
300 300
301 301 Methods
302 302 -------
303 303
304 304 spin
305 305 flushes incoming results and registration state changes
306 306 control methods spin, and requesting `ids` also ensures the state is up to date
307 307
308 308 wait
309 309 wait on one or more msg_ids
310 310
311 311 execution methods
312 312 apply
313 313 legacy: execute, run
314 314
315 315 data movement
316 316 push, pull, scatter, gather
317 317
318 318 query methods
319 319 queue_status, get_result, purge, result_status
320 320
321 321 control methods
322 322 abort, shutdown
323 323
324 324 """
325 325
326 326
327 327 block = Bool(False)
328 328 outstanding = Set()
329 329 results = Instance('collections.defaultdict', (dict,))
330 330 metadata = Instance('collections.defaultdict', (Metadata,))
331 331 history = List()
332 332 debug = Bool(False)
333 333 _spin_thread = Any()
334 334 _stop_spinning = Any()
335 335
336 336 profile=Unicode()
337 337 def _profile_default(self):
338 338 if BaseIPythonApplication.initialized():
339 339 # an IPython app *might* be running, try to get its profile
340 340 try:
341 341 return BaseIPythonApplication.instance().profile
342 342 except (AttributeError, MultipleInstanceError):
343 343 # could be a *different* subclass of config.Application,
344 344 # which would raise one of these two errors.
345 345 return u'default'
346 346 else:
347 347 return u'default'
348 348
349 349
350 350 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 351 _ids = List()
352 352 _connected=Bool(False)
353 353 _ssh=Bool(False)
354 354 _context = Instance('zmq.Context')
355 355 _config = Dict()
356 356 _engines=Instance(util.ReverseDict, (), {})
357 357 # _hub_socket=Instance('zmq.Socket')
358 358 _query_socket=Instance('zmq.Socket')
359 359 _control_socket=Instance('zmq.Socket')
360 360 _iopub_socket=Instance('zmq.Socket')
361 361 _notification_socket=Instance('zmq.Socket')
362 362 _mux_socket=Instance('zmq.Socket')
363 363 _task_socket=Instance('zmq.Socket')
364 364 _task_scheme=Unicode()
365 365 _closed = False
366 366 _ignored_control_replies=Integer(0)
367 367 _ignored_hub_replies=Integer(0)
368 368
369 369 def __new__(self, *args, **kw):
370 370 # don't raise on positional args
371 371 return HasTraits.__new__(self, **kw)
372 372
373 373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 374 context=None, debug=False,
375 375 sshserver=None, sshkey=None, password=None, paramiko=None,
376 376 timeout=10, cluster_id=None, **extra_args
377 377 ):
378 378 if profile:
379 379 super(Client, self).__init__(debug=debug, profile=profile)
380 380 else:
381 381 super(Client, self).__init__(debug=debug)
382 382 if context is None:
383 383 context = zmq.Context.instance()
384 384 self._context = context
385 385 self._stop_spinning = Event()
386 386
387 387 if 'url_or_file' in extra_args:
388 388 url_file = extra_args['url_or_file']
389 389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390 390
391 391 if url_file and util.is_url(url_file):
392 392 raise ValueError("single urls cannot be specified, url-files must be used.")
393 393
394 394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395 395
396 396 if self._cd is not None:
397 397 if url_file is None:
398 398 if not cluster_id:
399 399 client_json = 'ipcontroller-client.json'
400 400 else:
401 401 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 402 url_file = pjoin(self._cd.security_dir, client_json)
403 403 if url_file is None:
404 404 raise ValueError(
405 405 "I can't find enough information to connect to a hub!"
406 406 " Please specify at least one of url_file or profile."
407 407 )
408 408
409 409 with open(url_file) as f:
410 410 cfg = json.load(f)
411 411
412 412 self._task_scheme = cfg['task_scheme']
413 413
414 414 # sync defaults from args, json:
415 415 if sshserver:
416 416 cfg['ssh'] = sshserver
417 417
418 418 location = cfg.setdefault('location', None)
419 419
420 420 proto,addr = cfg['interface'].split('://')
421 421 addr = util.disambiguate_ip_address(addr, location)
422 422 cfg['interface'] = "%s://%s" % (proto, addr)
423 423
424 424 # turn interface,port into full urls:
425 425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427 427
428 428 url = cfg['registration']
429 429
430 430 if location is not None and addr == '127.0.0.1':
431 431 # location specified, and connection is expected to be local
432 432 if location not in LOCAL_IPS and not sshserver:
433 433 # load ssh from JSON *only* if the controller is not on
434 434 # this machine
435 435 sshserver=cfg['ssh']
436 436 if location not in LOCAL_IPS and not sshserver:
437 437 # warn if no ssh specified, but SSH is probably needed
438 438 # This is only a warning, because the most likely cause
439 439 # is a local Controller on a laptop whose IP is dynamic
440 440 warnings.warn("""
441 441 Controller appears to be listening on localhost, but not on this machine.
442 442 If this is true, you should specify Client(...,sshserver='you@%s')
443 443 or instruct your controller to listen on an external IP."""%location,
444 444 RuntimeWarning)
445 445 elif not sshserver:
446 446 # otherwise sync with cfg
447 447 sshserver = cfg['ssh']
448 448
449 449 self._config = cfg
450 450
451 451 self._ssh = bool(sshserver or sshkey or password)
452 452 if self._ssh and sshserver is None:
453 453 # default to ssh via localhost
454 454 sshserver = addr
455 455 if self._ssh and password is None:
456 456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 457 password=False
458 458 else:
459 459 password = getpass("SSH Password for %s: "%sshserver)
460 460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461 461
462 462 # configure and construct the session
463 463 extra_args['packer'] = cfg['pack']
464 464 extra_args['unpacker'] = cfg['unpack']
465 465 extra_args['key'] = cast_bytes(cfg['exec_key'])
466 466
467 467 self.session = Session(**extra_args)
468 468
469 469 self._query_socket = self._context.socket(zmq.DEALER)
470 470
471 471 if self._ssh:
472 472 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
473 473 else:
474 474 self._query_socket.connect(cfg['registration'])
475 475
476 476 self.session.debug = self.debug
477 477
478 478 self._notification_handlers = {'registration_notification' : self._register_engine,
479 479 'unregistration_notification' : self._unregister_engine,
480 480 'shutdown_notification' : lambda msg: self.close(),
481 481 }
482 482 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
483 483 'apply_reply' : self._handle_apply_reply}
484 484 self._connect(sshserver, ssh_kwargs, timeout)
485 485
486 486 # last step: setup magics, if we are in IPython:
487 487
488 488 try:
489 489 ip = get_ipython()
490 490 except NameError:
491 491 return
492 492 else:
493 493 if 'px' not in ip.magics_manager.magics:
494 494 # in IPython but we are the first Client.
495 495 # activate a default view for parallel magics.
496 496 self.activate()
497 497
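A minimal usage sketch for the constructor above, assuming a cluster already started with `ipcluster start`; the profile name and ssh host in the comments are hypothetical placeholders:

    from IPython.parallel import Client

    rc = Client()                                  # read default profile's ipcontroller-client.json
    # rc = Client(profile='mycluster')             # or a named profile
    # rc = Client(sshserver='me@gateway.example')  # or tunnel to a remote controller
    print rc.ids                                   # currently-registered engine ids
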
498 498 def __del__(self):
499 499 """cleanup sockets, but _not_ context."""
500 500 self.close()
501 501
502 502 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
503 503 if ipython_dir is None:
504 504 ipython_dir = get_ipython_dir()
505 505 if profile_dir is not None:
506 506 try:
507 507 self._cd = ProfileDir.find_profile_dir(profile_dir)
508 508 return
509 509 except ProfileDirError:
510 510 pass
511 511 elif profile is not None:
512 512 try:
513 513 self._cd = ProfileDir.find_profile_dir_by_name(
514 514 ipython_dir, profile)
515 515 return
516 516 except ProfileDirError:
517 517 pass
518 518 self._cd = None
519 519
520 520 def _update_engines(self, engines):
521 521 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
522 522 for k,v in engines.iteritems():
523 523 eid = int(k)
524 524 if eid not in self._engines:
525 525 self._ids.append(eid)
526 526 self._engines[eid] = v
527 527 self._ids = sorted(self._ids)
528 528 if sorted(self._engines.keys()) != range(len(self._engines)) and \
529 529 self._task_scheme == 'pure' and self._task_socket:
530 530 self._stop_scheduling_tasks()
531 531
532 532 def _stop_scheduling_tasks(self):
533 533 """Stop scheduling tasks because an engine has been unregistered
534 534 from a pure ZMQ scheduler.
535 535 """
536 536 self._task_socket.close()
537 537 self._task_socket = None
538 538 msg = "An engine has been unregistered, and we are using pure " +\
539 539 "ZMQ task scheduling. Task farming will be disabled."
540 540 if self.outstanding:
541 541 msg += " If you were running tasks when this happened, " +\
542 542 "some `outstanding` msg_ids may never resolve."
543 543 warnings.warn(msg, RuntimeWarning)
544 544
545 545 def _build_targets(self, targets):
546 546 """Turn valid target IDs or 'all' into two lists:
547 547 (int_ids, uuids).
548 548 """
549 549 if not self._ids:
550 550 # flush notification socket if no engines yet, just in case
551 551 if not self.ids:
552 552 raise error.NoEnginesRegistered("Can't build targets without any engines")
553 553
554 554 if targets is None:
555 555 targets = self._ids
556 556 elif isinstance(targets, basestring):
557 557 if targets.lower() == 'all':
558 558 targets = self._ids
559 559 else:
560 560 raise TypeError("%r not valid str target, must be 'all'"%(targets))
561 561 elif isinstance(targets, int):
562 562 if targets < 0:
563 563 targets = self.ids[targets]
564 564 if targets not in self._ids:
565 565 raise IndexError("No such engine: %i"%targets)
566 566 targets = [targets]
567 567
568 568 if isinstance(targets, slice):
569 569 indices = range(len(self._ids))[targets]
570 570 ids = self.ids
571 571 targets = [ ids[i] for i in indices ]
572 572
573 573 if not isinstance(targets, (tuple, list, xrange)):
574 574 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
575 575
576 576 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
577 577
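A sketch of the target forms `_build_targets` accepts, exercised via the public indexing API (with `rc` the connected Client from the first sketch):

    rc[0]       # a single engine, by int id
    rc[::2]     # a slice of engines
    rc[[0, 2]]  # an explicit list of int ids
    rc[:]       # all engines registered right now
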
578 578 def _connect(self, sshserver, ssh_kwargs, timeout):
579 579 """setup all our socket connections to the cluster. This is called from
580 580 __init__."""
581 581
582 582 # Maybe allow reconnecting?
583 583 if self._connected:
584 584 return
585 585 self._connected=True
586 586
587 587 def connect_socket(s, url):
588 588 # url = util.disambiguate_url(url, self._config['location'])
589 589 if self._ssh:
590 590 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
591 591 else:
592 592 return s.connect(url)
593 593
594 594 self.session.send(self._query_socket, 'connection_request')
595 595 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
596 596 poller = zmq.Poller()
597 597 poller.register(self._query_socket, zmq.POLLIN)
598 598 # poll expects milliseconds, timeout is seconds
599 599 evts = poller.poll(timeout*1000)
600 600 if not evts:
601 601 raise error.TimeoutError("Hub connection request timed out")
602 602 idents,msg = self.session.recv(self._query_socket,mode=0)
603 603 if self.debug:
604 604 pprint(msg)
605 605 content = msg['content']
606 606 # self._config['registration'] = dict(content)
607 607 cfg = self._config
608 608 if content['status'] == 'ok':
609 609 self._mux_socket = self._context.socket(zmq.DEALER)
610 610 connect_socket(self._mux_socket, cfg['mux'])
611 611
612 612 self._task_socket = self._context.socket(zmq.DEALER)
613 613 connect_socket(self._task_socket, cfg['task'])
614 614
615 615 self._notification_socket = self._context.socket(zmq.SUB)
616 616 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
617 617 connect_socket(self._notification_socket, cfg['notification'])
618 618
619 619 self._control_socket = self._context.socket(zmq.DEALER)
620 620 connect_socket(self._control_socket, cfg['control'])
621 621
622 622 self._iopub_socket = self._context.socket(zmq.SUB)
623 623 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
624 624 connect_socket(self._iopub_socket, cfg['iopub'])
625 625
626 626 self._update_engines(dict(content['engines']))
627 627 else:
628 628 self._connected = False
629 629 raise Exception("Failed to connect!")
630 630
631 631 #--------------------------------------------------------------------------
632 632 # handlers and callbacks for incoming messages
633 633 #--------------------------------------------------------------------------
634 634
635 635 def _unwrap_exception(self, content):
636 636 """unwrap exception, and remap engine_id to int."""
637 637 e = error.unwrap_exception(content)
638 638 # print e.traceback
639 639 if e.engine_info:
640 640 e_uuid = e.engine_info['engine_uuid']
641 641 eid = self._engines[e_uuid]
642 642 e.engine_info['engine_id'] = eid
643 643 return e
644 644
645 645 def _extract_metadata(self, msg):
646 646 header = msg['header']
647 647 parent = msg['parent_header']
648 648 msg_meta = msg['metadata']
649 649 content = msg['content']
650 650 md = {'msg_id' : parent['msg_id'],
651 651 'received' : datetime.now(),
652 652 'engine_uuid' : msg_meta.get('engine', None),
653 653 'follow' : msg_meta.get('follow', []),
654 654 'after' : msg_meta.get('after', []),
655 655 'status' : content['status'],
656 656 }
657 657
658 658 if md['engine_uuid'] is not None:
659 659 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
660 660
661 661 if 'date' in parent:
662 662 md['submitted'] = parent['date']
663 663 if 'started' in msg_meta:
664 664 md['started'] = msg_meta['started']
665 665 if 'date' in header:
666 666 md['completed'] = header['date']
667 667 return md
668 668
669 669 def _register_engine(self, msg):
670 670 """Register a new engine, and update our connection info."""
671 671 content = msg['content']
672 672 eid = content['id']
673 673 d = {eid : content['uuid']}
674 674 self._update_engines(d)
675 675
676 676 def _unregister_engine(self, msg):
677 677 """Unregister an engine that has died."""
678 678 content = msg['content']
679 679 eid = int(content['id'])
680 680 if eid in self._ids:
681 681 self._ids.remove(eid)
682 682 uuid = self._engines.pop(eid)
683 683
684 684 self._handle_stranded_msgs(eid, uuid)
685 685
686 686 if self._task_socket and self._task_scheme == 'pure':
687 687 self._stop_scheduling_tasks()
688 688
689 689 def _handle_stranded_msgs(self, eid, uuid):
690 690 """Handle messages known to be on an engine when the engine unregisters.
691 691
692 692 It is possible that this will fire prematurely - that is, an engine will
693 693 go down after completing a result, and the client will be notified
694 694 of the unregistration and later receive the successful result.
695 695 """
696 696
697 697 outstanding = self._outstanding_dict[uuid]
698 698
699 699 for msg_id in list(outstanding):
700 700 if msg_id in self.results:
701 701 # we already have the result
702 702 continue
703 703 try:
704 704 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
705 705 except:
706 706 content = error.wrap_exception()
707 707 # build a fake message:
708 708 msg = self.session.msg('apply_reply', content=content)
709 709 msg['parent_header']['msg_id'] = msg_id
710 710 msg['metadata']['engine'] = uuid
711 711 self._handle_apply_reply(msg)
712 712
713 713 def _handle_execute_reply(self, msg):
714 714 """Save the reply to an execute_request into our results.
715 715
716 716 execute replies are wrapped in ExecuteReply objects, rather than unserialized like apply results.
717 717 """
718 718
719 719 parent = msg['parent_header']
720 720 msg_id = parent['msg_id']
721 721 if msg_id not in self.outstanding:
722 722 if msg_id in self.history:
723 723 print ("got stale result: %s"%msg_id)
724 724 else:
725 725 print ("got unknown result: %s"%msg_id)
726 726 else:
727 727 self.outstanding.remove(msg_id)
728 728
729 729 content = msg['content']
730 730 header = msg['header']
731 731
732 732 # construct metadata:
733 733 md = self.metadata[msg_id]
734 734 md.update(self._extract_metadata(msg))
735 735 # is this redundant?
736 736 self.metadata[msg_id] = md
737 737
738 738 e_outstanding = self._outstanding_dict[md['engine_uuid']]
739 739 if msg_id in e_outstanding:
740 740 e_outstanding.remove(msg_id)
741 741
742 742 # construct result:
743 743 if content['status'] == 'ok':
744 744 self.results[msg_id] = ExecuteReply(msg_id, content, md)
745 745 elif content['status'] == 'aborted':
746 746 self.results[msg_id] = error.TaskAborted(msg_id)
747 747 elif content['status'] == 'resubmitted':
748 748 # TODO: handle resubmission
749 749 pass
750 750 else:
751 751 self.results[msg_id] = self._unwrap_exception(content)
752 752
753 753 def _handle_apply_reply(self, msg):
754 754 """Save the reply to an apply_request into our results."""
755 755 parent = msg['parent_header']
756 756 msg_id = parent['msg_id']
757 757 if msg_id not in self.outstanding:
758 758 if msg_id in self.history:
759 759 print ("got stale result: %s"%msg_id)
760 760 print self.results[msg_id]
761 761 print msg
762 762 else:
763 763 print ("got unknown result: %s"%msg_id)
764 764 else:
765 765 self.outstanding.remove(msg_id)
766 766 content = msg['content']
767 767 header = msg['header']
768 768
769 769 # construct metadata:
770 770 md = self.metadata[msg_id]
771 771 md.update(self._extract_metadata(msg))
772 772 # is this redundant?
773 773 self.metadata[msg_id] = md
774 774
775 775 e_outstanding = self._outstanding_dict[md['engine_uuid']]
776 776 if msg_id in e_outstanding:
777 777 e_outstanding.remove(msg_id)
778 778
779 779 # construct result:
780 780 if content['status'] == 'ok':
781 781 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
782 782 elif content['status'] == 'aborted':
783 783 self.results[msg_id] = error.TaskAborted(msg_id)
784 784 elif content['status'] == 'resubmitted':
785 785 # TODO: handle resubmission
786 786 pass
787 787 else:
788 788 self.results[msg_id] = self._unwrap_exception(content)
789 789
790 790 def _flush_notifications(self):
791 791 """Flush notifications of engine registrations waiting
792 792 in ZMQ queue."""
793 793 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
794 794 while msg is not None:
795 795 if self.debug:
796 796 pprint(msg)
797 797 msg_type = msg['header']['msg_type']
798 798 handler = self._notification_handlers.get(msg_type, None)
799 799 if handler is None:
800 800 raise Exception("Unhandled message type: %s"%msg.msg_type)
801 801 else:
802 802 handler(msg)
803 803 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
804 804
805 805 def _flush_results(self, sock):
806 806 """Flush task or queue results waiting in ZMQ queue."""
807 807 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
808 808 while msg is not None:
809 809 if self.debug:
810 810 pprint(msg)
811 811 msg_type = msg['header']['msg_type']
812 812 handler = self._queue_handlers.get(msg_type, None)
813 813 if handler is None:
814 814 raise Exception("Unhandled message type: %s"%msg.msg_type)
815 815 else:
816 816 handler(msg)
817 817 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
818 818
819 819 def _flush_control(self, sock):
820 820 """Flush replies from the control channel waiting
821 821 in the ZMQ queue.
822 822
823 823 Currently: ignore them."""
824 824 if self._ignored_control_replies <= 0:
825 825 return
826 826 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
827 827 while msg is not None:
828 828 self._ignored_control_replies -= 1
829 829 if self.debug:
830 830 pprint(msg)
831 831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832 832
833 833 def _flush_ignored_control(self):
834 834 """flush ignored control replies"""
835 835 while self._ignored_control_replies > 0:
836 836 self.session.recv(self._control_socket)
837 837 self._ignored_control_replies -= 1
838 838
839 839 def _flush_ignored_hub_replies(self):
840 840 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
841 841 while msg is not None:
842 842 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843 843
844 844 def _flush_iopub(self, sock):
845 845 """Flush replies from the iopub channel waiting
846 846 in the ZMQ queue.
847 847 """
848 848 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
849 849 while msg is not None:
850 850 if self.debug:
851 851 pprint(msg)
852 852 parent = msg['parent_header']
853 853 # ignore IOPub messages with no parent.
854 854 # Caused by print statements or warnings from before the first execution.
855 855 if not parent:
856 856 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK); continue
857 857 msg_id = parent['msg_id']
858 858 content = msg['content']
859 859 header = msg['header']
860 860 msg_type = msg['header']['msg_type']
861 861
862 862 # init metadata:
863 863 md = self.metadata[msg_id]
864 864
865 865 if msg_type == 'stream':
866 866 name = content['name']
867 867 s = md[name] or ''
868 868 md[name] = s + content['data']
869 869 elif msg_type == 'pyerr':
870 870 md.update({'pyerr' : self._unwrap_exception(content)})
871 871 elif msg_type == 'pyin':
872 872 md.update({'pyin' : content['code']})
873 873 elif msg_type == 'display_data':
874 874 md['outputs'].append(content)
875 875 elif msg_type == 'pyout':
876 876 md['pyout'] = content
877 877 elif msg_type == 'data_message':
878 878 data, remainder = serialize.unserialize_object(msg['buffers'])
879 879 md['data'].update(data)
880 880 elif msg_type == 'status':
881 881 # idle message comes after all outputs
882 882 if content['execution_state'] == 'idle':
883 883 md['outputs_ready'] = True
884 884 else:
885 885 # unhandled msg_type
886 886 pass
887 887
888 888 # redundant?
889 889 self.metadata[msg_id] = md
890 890
891 891 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
892 892
893 893 #--------------------------------------------------------------------------
894 894 # len, getitem
895 895 #--------------------------------------------------------------------------
896 896
897 897 def __len__(self):
898 898 """len(client) returns # of engines."""
899 899 return len(self.ids)
900 900
901 901 def __getitem__(self, key):
902 902 """index access returns DirectView multiplexer objects
903 903
904 904 Must be int, slice, or list/tuple/xrange of ints"""
905 905 if not isinstance(key, (int, slice, tuple, list, xrange)):
906 906 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
907 907 else:
908 908 return self.direct_view(key)
909 909
910 910 #--------------------------------------------------------------------------
911 911 # Begin public methods
912 912 #--------------------------------------------------------------------------
913 913
914 914 @property
915 915 def ids(self):
916 916 """Always up-to-date ids property."""
917 917 self._flush_notifications()
918 918 # always copy:
919 919 return list(self._ids)
920 920
921 921 def activate(self, targets='all', suffix=''):
922 922 """Create a DirectView and register it with IPython magics
923 923
924 924 Defines the magics `%px, %autopx, %pxresult, %%px`
925 925
926 926 Parameters
927 927 ----------
928 928
929 929 targets: int, list of ints, or 'all'
930 930 The engines on which the view's magics will run
931 931 suffix: str [default: '']
932 932 The suffix, if any, for the magics. This allows you to have
933 933 multiple views associated with parallel magics at the same time.
934 934
935 935 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
936 936 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
937 937 on engine 0.
938 938 """
939 939 view = self.direct_view(targets)
940 940 view.block = True
941 941 view.activate(suffix)
942 942 return view
943 943
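A sketch of `activate` in an interactive session; the parallel magics it defines only exist inside IPython (assuming a connected `rc`):

    rc.activate()                        # defines %px, %autopx, %pxresult, %%px
    rc.activate(targets=0, suffix='0')   # engine 0 only: %px0, %pxresult0, ...
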
944 944 def close(self):
945 945 if self._closed:
946 946 return
947 947 self.stop_spin_thread()
948 948 snames = filter(lambda n: n.endswith('socket'), dir(self))
949 949 for socket in map(lambda name: getattr(self, name), snames):
950 950 if isinstance(socket, zmq.Socket) and not socket.closed:
951 951 socket.close()
952 952 self._closed = True
953 953
954 954 def _spin_every(self, interval=1):
955 955 """target func for use in spin_thread"""
956 956 while True:
957 957 if self._stop_spinning.is_set():
958 958 return
959 959 time.sleep(interval)
960 960 self.spin()
961 961
962 962 def spin_thread(self, interval=1):
963 963 """call Client.spin() in a background thread on some regular interval
964 964
965 965 This helps ensure that messages don't pile up too much in the zmq queue
966 966 while you are working on other things, or just leaving an idle terminal.
967 967
968 968 It also helps limit potential padding of the `received` timestamp
969 969 on AsyncResult objects, used for timings.
970 970
971 971 Parameters
972 972 ----------
973 973
974 974 interval : float, optional
975 975 The interval on which to spin the client in the background thread
976 976 (simply passed to time.sleep).
977 977
978 978 Notes
979 979 -----
980 980
981 981 For precision timing, you may want to use this method to put a bound
982 982 on the jitter (in seconds) in `received` timestamps used
983 983 in AsyncResult.wall_time.
984 984
985 985 """
986 986 if self._spin_thread is not None:
987 987 self.stop_spin_thread()
988 988 self._stop_spinning.clear()
989 989 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
990 990 self._spin_thread.daemon = True
991 991 self._spin_thread.start()
992 992
993 993 def stop_spin_thread(self):
994 994 """stop background spin_thread, if any"""
995 995 if self._spin_thread is not None:
996 996 self._stop_spinning.set()
997 997 self._spin_thread.join()
998 998 self._spin_thread = None
999 999
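A sketch of the background spin thread, used here to bound jitter in `received` timestamps; the 0.1s interval is an arbitrary choice (assuming a connected `rc`):

    rc.spin_thread(interval=0.1)
    try:
        ar = rc[:].apply_async(lambda: 42)
        ar.wait()
    finally:
        rc.stop_spin_thread()
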
1000 1000 def spin(self):
1001 1001 """Flush any registration notifications and execution results
1002 1002 waiting in the ZMQ queue.
1003 1003 """
1004 1004 if self._notification_socket:
1005 1005 self._flush_notifications()
1006 1006 if self._iopub_socket:
1007 1007 self._flush_iopub(self._iopub_socket)
1008 1008 if self._mux_socket:
1009 1009 self._flush_results(self._mux_socket)
1010 1010 if self._task_socket:
1011 1011 self._flush_results(self._task_socket)
1012 1012 if self._control_socket:
1013 1013 self._flush_control(self._control_socket)
1014 1014 if self._query_socket:
1015 1015 self._flush_ignored_hub_replies()
1016 1016
1017 1017 def wait(self, jobs=None, timeout=-1):
1018 1018 """waits on one or more `jobs`, for up to `timeout` seconds.
1019 1019
1020 1020 Parameters
1021 1021 ----------
1022 1022
1023 1023 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1024 1024 ints are indices to self.history
1025 1025 strs are msg_ids
1026 1026 default: wait on all outstanding messages
1027 1027 timeout : float
1028 1028 a time in seconds, after which to give up.
1029 1029 default is -1, which means no timeout
1030 1030
1031 1031 Returns
1032 1032 -------
1033 1033
1034 1034 True : when all msg_ids are done
1035 1035 False : timeout reached, some msg_ids still outstanding
1036 1036 """
1037 1037 tic = time.time()
1038 1038 if jobs is None:
1039 1039 theids = self.outstanding
1040 1040 else:
1041 1041 if isinstance(jobs, (int, basestring, AsyncResult)):
1042 1042 jobs = [jobs]
1043 1043 theids = set()
1044 1044 for job in jobs:
1045 1045 if isinstance(job, int):
1046 1046 # index access
1047 1047 job = self.history[job]
1048 1048 elif isinstance(job, AsyncResult):
1049 1049 map(theids.add, job.msg_ids)
1050 1050 continue
1051 1051 theids.add(job)
1052 1052 if not theids.intersection(self.outstanding):
1053 1053 return True
1054 1054 self.spin()
1055 1055 while theids.intersection(self.outstanding):
1056 1056 if timeout >= 0 and ( time.time()-tic ) > timeout:
1057 1057 break
1058 1058 time.sleep(1e-3)
1059 1059 self.spin()
1060 1060 return len(theids.intersection(self.outstanding)) == 0
1061 1061
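A sketch of the two ways to call `wait` (assuming a connected `rc`):

    ar = rc[:].apply_async(sum, range(10))
    rc.wait()                          # block until *all* outstanding messages are done
    done = rc.wait([ar], timeout=5)    # or wait on specific jobs; False if timed out
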
1062 1062 #--------------------------------------------------------------------------
1063 1063 # Control methods
1064 1064 #--------------------------------------------------------------------------
1065 1065
1066 1066 @spin_first
1067 1067 def clear(self, targets=None, block=None):
1068 1068 """Clear the namespace in target(s)."""
1069 1069 block = self.block if block is None else block
1070 1070 targets = self._build_targets(targets)[0]
1071 1071 for t in targets:
1072 1072 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1073 1073 error = False
1074 1074 if block:
1075 1075 self._flush_ignored_control()
1076 1076 for i in range(len(targets)):
1077 1077 idents,msg = self.session.recv(self._control_socket,0)
1078 1078 if self.debug:
1079 1079 pprint(msg)
1080 1080 if msg['content']['status'] != 'ok':
1081 1081 error = self._unwrap_exception(msg['content'])
1082 1082 else:
1083 1083 self._ignored_control_replies += len(targets)
1084 1084 if error:
1085 1085 raise error
1086 1086
1087 1087
1088 1088 @spin_first
1089 1089 def abort(self, jobs=None, targets=None, block=None):
1090 1090 """Abort specific jobs from the execution queues of target(s).
1091 1091
1092 1092 This is a mechanism to prevent jobs that have already been submitted
1093 1093 from executing.
1094 1094
1095 1095 Parameters
1096 1096 ----------
1097 1097
1098 1098 jobs : msg_id, list of msg_ids, or AsyncResult
1099 1099 The jobs to be aborted
1100 1100
1101 1101 If unspecified/None: abort all outstanding jobs.
1102 1102
1103 1103 """
1104 1104 block = self.block if block is None else block
1105 1105 jobs = jobs if jobs is not None else list(self.outstanding)
1106 1106 targets = self._build_targets(targets)[0]
1107 1107
1108 1108 msg_ids = []
1109 1109 if isinstance(jobs, (basestring,AsyncResult)):
1110 1110 jobs = [jobs]
1111 1111 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1112 1112 if bad_ids:
1113 1113 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1114 1114 for j in jobs:
1115 1115 if isinstance(j, AsyncResult):
1116 1116 msg_ids.extend(j.msg_ids)
1117 1117 else:
1118 1118 msg_ids.append(j)
1119 1119 content = dict(msg_ids=msg_ids)
1120 1120 for t in targets:
1121 1121 self.session.send(self._control_socket, 'abort_request',
1122 1122 content=content, ident=t)
1123 1123 error = False
1124 1124 if block:
1125 1125 self._flush_ignored_control()
1126 1126 for i in range(len(targets)):
1127 1127 idents,msg = self.session.recv(self._control_socket,0)
1128 1128 if self.debug:
1129 1129 pprint(msg)
1130 1130 if msg['content']['status'] != 'ok':
1131 1131 error = self._unwrap_exception(msg['content'])
1132 1132 else:
1133 1133 self._ignored_control_replies += len(targets)
1134 1134 if error:
1135 1135 raise error
1136 1136
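A sketch of aborting queued work; note that a job already running on an engine cannot be stopped this way (assuming a connected `rc`):

    import time
    lview = rc.load_balanced_view()
    ar = lview.apply_async(time.sleep, 60)
    rc.abort(ar)     # abort this job, if it is still queued
    rc.abort()       # or abort everything still outstanding
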
1137 1137 @spin_first
1138 1138 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1139 1139 """Terminates one or more engine processes, optionally including the hub.
1140 1140
1141 1141 Parameters
1142 1142 ----------
1143 1143
1144 1144 targets: list of ints or 'all' [default: all]
1145 1145 Which engines to shutdown.
1146 1146 hub: bool [default: False]
1147 1147 Whether to include the Hub. hub=True implies targets='all'.
1148 1148 block: bool [default: self.block]
1149 1149 Whether to wait for clean shutdown replies or not.
1150 1150 restart: bool [default: False]
1151 1151 NOT IMPLEMENTED
1152 1152 whether to restart engines after shutting them down.
1153 1153 """
1154
1154 from IPython.parallel.error import NoEnginesRegistered
1155 1155 if restart:
1156 1156 raise NotImplementedError("Engine restart is not yet implemented")
1157 1157
1158 1158 block = self.block if block is None else block
1159 1159 if hub:
1160 1160 targets = 'all'
1161 targets = self._build_targets(targets)[0]
1161 try:
1162 targets = self._build_targets(targets)[0]
1163 except NoEnginesRegistered:
1164 targets = []
1162 1165 for t in targets:
1163 1166 self.session.send(self._control_socket, 'shutdown_request',
1164 1167 content={'restart':restart},ident=t)
1165 1168 error = False
1166 1169 if block or hub:
1167 1170 self._flush_ignored_control()
1168 1171 for i in range(len(targets)):
1169 1172 idents,msg = self.session.recv(self._control_socket, 0)
1170 1173 if self.debug:
1171 1174 pprint(msg)
1172 1175 if msg['content']['status'] != 'ok':
1173 1176 error = self._unwrap_exception(msg['content'])
1174 1177 else:
1175 1178 self._ignored_control_replies += len(targets)
1176 1179
1177 1180 if hub:
1178 1181 time.sleep(0.25)
1179 1182 self.session.send(self._query_socket, 'shutdown_request')
1180 1183 idents,msg = self.session.recv(self._query_socket, 0)
1181 1184 if self.debug:
1182 1185 pprint(msg)
1183 1186 if msg['content']['status'] != 'ok':
1184 1187 error = self._unwrap_exception(msg['content'])
1185 1188
1186 1189 if error:
1187 1190 raise error
1188 1191
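A sketch of the case this commit fixes: with no engines registered, `_build_targets` raises NoEnginesRegistered, which `shutdown` now catches so a hub-only shutdown still works (assuming a connected `rc`):

    # even with no engines registered (rc.ids == []), this no longer raises:
    rc.shutdown(hub=True, block=True)
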
1189 1192 #--------------------------------------------------------------------------
1190 1193 # Execution related methods
1191 1194 #--------------------------------------------------------------------------
1192 1195
1193 1196 def _maybe_raise(self, result):
1194 1197 """wrapper for maybe raising an exception if apply failed."""
1195 1198 if isinstance(result, error.RemoteError):
1196 1199 raise result
1197 1200
1198 1201 return result
1199 1202
1200 1203 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1201 1204 ident=None):
1202 1205 """construct and send an apply message via a socket.
1203 1206
1204 1207 This is the principal method with which all engine execution is performed by views.
1205 1208 """
1206 1209
1207 1210 if self._closed:
1208 1211 raise RuntimeError("Client cannot be used after its sockets have been closed")
1209 1212
1210 1213 # defaults:
1211 1214 args = args if args is not None else []
1212 1215 kwargs = kwargs if kwargs is not None else {}
1213 1216 metadata = metadata if metadata is not None else {}
1214 1217
1215 1218 # validate arguments
1216 1219 if not callable(f) and not isinstance(f, Reference):
1217 1220 raise TypeError("f must be callable, not %s"%type(f))
1218 1221 if not isinstance(args, (tuple, list)):
1219 1222 raise TypeError("args must be tuple or list, not %s"%type(args))
1220 1223 if not isinstance(kwargs, dict):
1221 1224 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1222 1225 if not isinstance(metadata, dict):
1223 1226 raise TypeError("metadata must be dict, not %s"%type(metadata))
1224 1227
1225 1228 bufs = serialize.pack_apply_message(f, args, kwargs,
1226 1229 buffer_threshold=self.session.buffer_threshold,
1227 1230 item_threshold=self.session.item_threshold,
1228 1231 )
1229 1232
1230 1233 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1231 1234 metadata=metadata, track=track)
1232 1235
1233 1236 msg_id = msg['header']['msg_id']
1234 1237 self.outstanding.add(msg_id)
1235 1238 if ident:
1236 1239 # possibly routed to a specific engine
1237 1240 if isinstance(ident, list):
1238 1241 ident = ident[-1]
1239 1242 if ident in self._engines.values():
1240 1243 # save for later, in case of engine death
1241 1244 self._outstanding_dict[ident].add(msg_id)
1242 1245 self.history.append(msg_id)
1243 1246 self.metadata[msg_id]['submitted'] = datetime.now()
1244 1247
1245 1248 return msg
1246 1249
1247 1250 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1248 1251 """construct and send an execute request via a socket.
1249 1252
1250 1253 """
1251 1254
1252 1255 if self._closed:
1253 1256 raise RuntimeError("Client cannot be used after its sockets have been closed")
1254 1257
1255 1258 # defaults:
1256 1259 metadata = metadata if metadata is not None else {}
1257 1260
1258 1261 # validate arguments
1259 1262 if not isinstance(code, basestring):
1260 1263 raise TypeError("code must be text, not %s" % type(code))
1261 1264 if not isinstance(metadata, dict):
1262 1265 raise TypeError("metadata must be dict, not %s" % type(metadata))
1263 1266
1264 1267 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1265 1268
1266 1269
1267 1270 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1268 1271 metadata=metadata)
1269 1272
1270 1273 msg_id = msg['header']['msg_id']
1271 1274 self.outstanding.add(msg_id)
1272 1275 if ident:
1273 1276 # possibly routed to a specific engine
1274 1277 if isinstance(ident, list):
1275 1278 ident = ident[-1]
1276 1279 if ident in self._engines.values():
1277 1280 # save for later, in case of engine death
1278 1281 self._outstanding_dict[ident].add(msg_id)
1279 1282 self.history.append(msg_id)
1280 1283 self.metadata[msg_id]['submitted'] = datetime.now()
1281 1284
1282 1285 return msg
1283 1286
1284 1287 #--------------------------------------------------------------------------
1285 1288 # construct a View object
1286 1289 #--------------------------------------------------------------------------
1287 1290
1288 1291 def load_balanced_view(self, targets=None):
1289 1292 """construct a LoadBalancedView object.
1290 1293
1291 1294 If no arguments are specified, create a LoadBalancedView
1292 1295 using all engines.
1293 1296
1294 1297 Parameters
1295 1298 ----------
1296 1299
1297 1300 targets: list,slice,int,etc. [default: use all engines]
1298 1301 The subset of engines across which to load-balance
1299 1302 """
1300 1303 if targets == 'all':
1301 1304 targets = None
1302 1305 if targets is not None:
1303 1306 targets = self._build_targets(targets)[1]
1304 1307 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1305 1308
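A sketch of load-balanced execution (assuming a connected `rc`):

    lview = rc.load_balanced_view()    # balance across all engines
    amr = lview.map_async(lambda x: x**2, range(32))
    print amr.get()                    # [0, 1, 4, 9, ...]
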
1306 1309 def direct_view(self, targets='all'):
1307 1310 """construct a DirectView object.
1308 1311
1309 1312 If no targets are specified, create a DirectView using all engines.
1310 1313
1311 1314 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1312 1315 evaluate the target engines at each execution, whereas rc[:] will connect to
1313 1316 all *current* engines, and that list will not change.
1314 1317
1315 1318 That is, 'all' will always use all engines, whereas rc[:] will not use
1316 1319 engines added after the DirectView is constructed.
1317 1320
1318 1321 Parameters
1319 1322 ----------
1320 1323
1321 1324 targets: list,slice,int,etc. [default: use all engines]
1322 1325 The engines to use for the View
1323 1326 """
1324 1327 single = isinstance(targets, int)
1325 1328 # allow 'all' to be lazily evaluated at each execution
1326 1329 if targets != 'all':
1327 1330 targets = self._build_targets(targets)[1]
1328 1331 if single:
1329 1332 targets = targets[0]
1330 1333 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1331 1334
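A sketch of the lazy-'all' distinction described in the docstring above (assuming a connected `rc`):

    dv_all = rc.direct_view('all')   # re-evaluates the engine list at each execution
    dv_now = rc[:]                   # fixed to the engines registered right now
    dv_all.push(dict(a=5), block=True)
    print dv_all.pull('a', block=True)
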
1332 1335 #--------------------------------------------------------------------------
1333 1336 # Query methods
1334 1337 #--------------------------------------------------------------------------
1335 1338
1336 1339 @spin_first
1337 1340 def get_result(self, indices_or_msg_ids=None, block=None):
1338 1341 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1339 1342
1340 1343 If the client already has the results, no request to the Hub will be made.
1341 1344
1342 1345 This is a convenient way to construct AsyncResult objects, which are wrappers
1343 1346 that include metadata about execution, and allow for awaiting results that
1344 1347 were not submitted by this Client.
1345 1348
1346 1349 It can also be a convenient way to retrieve the metadata associated with
1347 1350 blocking execution, since it always retrieves the metadata as well.
1348 1351
1349 1352 Examples
1350 1353 --------
1351 1354 ::
1352 1355
1353 1356 In [10]: ar = client.get_result(client.history[-1])
1354 1357
1355 1358 Parameters
1356 1359 ----------
1357 1360
1358 1361 indices_or_msg_ids : integer history index, str msg_id, or list of either
1359 1362 The indices or msg_ids of results to be retrieved
1360 1363
1361 1364 block : bool
1362 1365 Whether to wait for the result to be done
1363 1366
1364 1367 Returns
1365 1368 -------
1366 1369
1367 1370 AsyncResult
1368 1371 A single AsyncResult object will always be returned.
1369 1372
1370 1373 AsyncHubResult
1371 1374 A subclass of AsyncResult that retrieves results from the Hub
1372 1375
1373 1376 """
1374 1377 block = self.block if block is None else block
1375 1378 if indices_or_msg_ids is None:
1376 1379 indices_or_msg_ids = -1
1377 1380
1378 1381 if not isinstance(indices_or_msg_ids, (list,tuple)):
1379 1382 indices_or_msg_ids = [indices_or_msg_ids]
1380 1383
1381 1384 theids = []
1382 1385 for id in indices_or_msg_ids:
1383 1386 if isinstance(id, int):
1384 1387 id = self.history[id]
1385 1388 if not isinstance(id, basestring):
1386 1389 raise TypeError("indices must be str or int, not %r"%id)
1387 1390 theids.append(id)
1388 1391
1389 1392 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1390 1393 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1391 1394
1392 1395 if remote_ids:
1393 1396 ar = AsyncHubResult(self, msg_ids=theids)
1394 1397 else:
1395 1398 ar = AsyncResult(self, msg_ids=theids)
1396 1399
1397 1400 if block:
1398 1401 ar.wait()
1399 1402
1400 1403 return ar
1401 1404
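A sketch of retrieving a result by msg_id, including one submitted by another client, using `hub_history` (defined later in this file; assuming a connected `rc`):

    msg_id = rc.hub_history()[-1]
    ar = rc.get_result(msg_id, block=True)
    print ar.get()
    print ar.metadata
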
1402 1405 @spin_first
1403 1406 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1404 1407 """Resubmit one or more tasks.
1405 1408
1406 1409 in-flight tasks may not be resubmitted.
1407 1410
1408 1411 Parameters
1409 1412 ----------
1410 1413
1411 1414 indices_or_msg_ids : integer history index, str msg_id, or list of either
1412 1415 The indices or msg_ids of results to be retrieved
1413 1416
1414 1417 block : bool
1415 1418 Whether to wait for the result to be done
1416 1419
1417 1420 Returns
1418 1421 -------
1419 1422
1420 1423 AsyncHubResult
1421 1424 A subclass of AsyncResult that retrieves results from the Hub
1422 1425
1423 1426 """
1424 1427 block = self.block if block is None else block
1425 1428 if indices_or_msg_ids is None:
1426 1429 indices_or_msg_ids = -1
1427 1430
1428 1431 if not isinstance(indices_or_msg_ids, (list,tuple)):
1429 1432 indices_or_msg_ids = [indices_or_msg_ids]
1430 1433
1431 1434 theids = []
1432 1435 for id in indices_or_msg_ids:
1433 1436 if isinstance(id, int):
1434 1437 id = self.history[id]
1435 1438 if not isinstance(id, basestring):
1436 1439 raise TypeError("indices must be str or int, not %r"%id)
1437 1440 theids.append(id)
1438 1441
1439 1442 content = dict(msg_ids = theids)
1440 1443
1441 1444 self.session.send(self._query_socket, 'resubmit_request', content)
1442 1445
1443 1446 zmq.select([self._query_socket], [], [])
1444 1447 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1445 1448 if self.debug:
1446 1449 pprint(msg)
1447 1450 content = msg['content']
1448 1451 if content['status'] != 'ok':
1449 1452 raise self._unwrap_exception(content)
1450 1453 mapping = content['resubmitted']
1451 1454 new_ids = [ mapping[msg_id] for msg_id in theids ]
1452 1455
1453 1456 ar = AsyncHubResult(self, msg_ids=new_ids)
1454 1457
1455 1458 if block:
1456 1459 ar.wait()
1457 1460
1458 1461 return ar
1459 1462
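A sketch of resubmitting a finished task; as noted above, in-flight tasks may not be resubmitted (assuming a connected `rc`):

    import os
    lview = rc.load_balanced_view()
    ar = lview.apply_async(os.getpid)
    ar.get()                         # make sure it has finished
    ar2 = rc.resubmit(ar.msg_ids)    # re-run it; returns an AsyncHubResult
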
1460 1463 @spin_first
1461 1464 def result_status(self, msg_ids, status_only=True):
1462 1465 """Check on the status of the result(s) of the apply request with `msg_ids`.
1463 1466
1464 1467 If status_only is False, then the actual results will be retrieved, else
1465 1468 only the status of the results will be checked.
1466 1469
1467 1470 Parameters
1468 1471 ----------
1469 1472
1470 1473 msg_ids : list of msg_ids
1471 1474 if int:
1472 1475 Passed as index to self.history for convenience.
1473 1476 status_only : bool (default: True)
1474 1477 if False:
1475 1478 Retrieve the actual results of completed tasks.
1476 1479
1477 1480 Returns
1478 1481 -------
1479 1482
1480 1483 results : dict
1481 1484 There will always be the keys 'pending' and 'completed', which will
1482 1485 be lists of msg_ids that are incomplete or complete. If `status_only`
1483 1486 is False, then completed results will be keyed by their `msg_id`.
1484 1487 """
1485 1488 if not isinstance(msg_ids, (list,tuple)):
1486 1489 msg_ids = [msg_ids]
1487 1490
1488 1491 theids = []
1489 1492 for msg_id in msg_ids:
1490 1493 if isinstance(msg_id, int):
1491 1494 msg_id = self.history[msg_id]
1492 1495 if not isinstance(msg_id, basestring):
1493 1496 raise TypeError("msg_ids must be str, not %r"%msg_id)
1494 1497 theids.append(msg_id)
1495 1498
1496 1499 completed = []
1497 1500 local_results = {}
1498 1501
1499 1502 # comment this block out to temporarily disable local shortcut:
1500 1503 for msg_id in theids:
1501 1504 if msg_id in self.results:
1502 1505 completed.append(msg_id)
1503 1506 local_results[msg_id] = self.results[msg_id]
1504 1507 theids.remove(msg_id)
1505 1508
1506 1509 if theids: # some not locally cached
1507 1510 content = dict(msg_ids=theids, status_only=status_only)
1508 1511 msg = self.session.send(self._query_socket, "result_request", content=content)
1509 1512 zmq.select([self._query_socket], [], [])
1510 1513 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1511 1514 if self.debug:
1512 1515 pprint(msg)
1513 1516 content = msg['content']
1514 1517 if content['status'] != 'ok':
1515 1518 raise self._unwrap_exception(content)
1516 1519 buffers = msg['buffers']
1517 1520 else:
1518 1521 content = dict(completed=[],pending=[])
1519 1522
1520 1523 content['completed'].extend(completed)
1521 1524
1522 1525 if status_only:
1523 1526 return content
1524 1527
1525 1528 failures = []
1526 1529 # load cached results into result:
1527 1530 content.update(local_results)
1528 1531
1529 1532 # update cache with results:
1530 1533 for msg_id in sorted(theids):
1531 1534 if msg_id in content['completed']:
1532 1535 rec = content[msg_id]
1533 1536 parent = rec['header']
1534 1537 header = rec['result_header']
1535 1538 rcontent = rec['result_content']
1536 1539 iodict = rec['io']
1537 1540 if isinstance(rcontent, str):
1538 1541 rcontent = self.session.unpack(rcontent)
1539 1542
1540 1543 md = self.metadata[msg_id]
1541 1544 md_msg = dict(
1542 1545 content=rcontent,
1543 1546 parent_header=parent,
1544 1547 header=header,
1545 1548 metadata=rec['result_metadata'],
1546 1549 )
1547 1550 md.update(self._extract_metadata(md_msg))
1548 1551 if rec.get('received'):
1549 1552 md['received'] = rec['received']
1550 1553 md.update(iodict)
1551 1554
1552 1555 if rcontent['status'] == 'ok':
1553 1556 if header['msg_type'] == 'apply_reply':
1554 1557 res,buffers = serialize.unserialize_object(buffers)
1555 1558 elif header['msg_type'] == 'execute_reply':
1556 1559 res = ExecuteReply(msg_id, rcontent, md)
1557 1560 else:
1558 1561 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1559 1562 else:
1560 1563 res = self._unwrap_exception(rcontent)
1561 1564 failures.append(res)
1562 1565
1563 1566 self.results[msg_id] = res
1564 1567 content[msg_id] = res
1565 1568
1566 1569 if len(theids) == 1 and failures:
1567 1570 raise failures[0]
1568 1571
1569 1572 error.collect_exceptions(failures, "result_status")
1570 1573 return content
1571 1574
1572 1575 @spin_first
1573 1576 def queue_status(self, targets='all', verbose=False):
1574 1577 """Fetch the status of engine queues.
1575 1578
1576 1579 Parameters
1577 1580 ----------
1578 1581
1579 1582 targets : int/str/list of ints/strs
1580 1583 the engines whose states are to be queried.
1581 1584 default : all
1582 1585 verbose : bool
1583 1586 Whether to return lengths only, or lists of ids for each element
1584 1587 """
1585 1588 if targets == 'all':
1586 1589 # allow 'all' to be evaluated on the engine
1587 1590 engine_ids = None
1588 1591 else:
1589 1592 engine_ids = self._build_targets(targets)[1]
1590 1593 content = dict(targets=engine_ids, verbose=verbose)
1591 1594 self.session.send(self._query_socket, "queue_request", content=content)
1592 1595 idents,msg = self.session.recv(self._query_socket, 0)
1593 1596 if self.debug:
1594 1597 pprint(msg)
1595 1598 content = msg['content']
1596 1599 status = content.pop('status')
1597 1600 if status != 'ok':
1598 1601 raise self._unwrap_exception(content)
1599 1602 content = rekey(content)
1600 1603 if isinstance(targets, int):
1601 1604 return content[targets]
1602 1605 else:
1603 1606 return content
1604 1607
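A sketch of inspecting queue state; the counts in the comment are illustrative only (assuming a connected `rc`):

    print rc.queue_status()                  # e.g. {0: {'queue': 0, 'completed': 3, 'tasks': 0}, ...}
    print rc.queue_status(0, verbose=True)   # lists of msg_ids for engine 0
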
1605 1608 @spin_first
1606 1609 def purge_results(self, jobs=[], targets=[]):
1607 1610 """Tell the Hub to forget results.
1608 1611
1609 1612 Individual results can be purged by msg_id, or the entire
1610 1613 history of specific targets can be purged.
1611 1614
1612 1615 Use `purge_results('all')` to scrub everything from the Hub's db.
1613 1616
1614 1617 Parameters
1615 1618 ----------
1616 1619
1617 1620 jobs : str or list of str or AsyncResult objects
1618 1621 the msg_ids whose results should be forgotten.
1619 1622 targets : int/str/list of ints/strs
1620 1623 The targets, by int_id, whose entire history is to be purged.
1621 1624
1622 1625 default : None
1623 1626 """
1624 1627 if not targets and not jobs:
1625 1628 raise ValueError("Must specify at least one of `targets` and `jobs`")
1626 1629 if targets:
1627 1630 targets = self._build_targets(targets)[1]
1628 1631
1629 1632 # construct msg_ids from jobs
1630 1633 if jobs == 'all':
1631 1634 msg_ids = jobs
1632 1635 else:
1633 1636 msg_ids = []
1634 1637 if isinstance(jobs, (basestring,AsyncResult)):
1635 1638 jobs = [jobs]
1636 1639 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1637 1640 if bad_ids:
1638 1641 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1639 1642 for j in jobs:
1640 1643 if isinstance(j, AsyncResult):
1641 1644 msg_ids.extend(j.msg_ids)
1642 1645 else:
1643 1646 msg_ids.append(j)
1644 1647
1645 1648 content = dict(engine_ids=targets, msg_ids=msg_ids)
1646 1649 self.session.send(self._query_socket, "purge_request", content=content)
1647 1650 idents, msg = self.session.recv(self._query_socket, 0)
1648 1651 if self.debug:
1649 1652 pprint(msg)
1650 1653 content = msg['content']
1651 1654 if content['status'] != 'ok':
1652 1655 raise self._unwrap_exception(content)
1653 1656
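A sketch of the three purge forms described above (assuming a connected `rc`):

    rc.purge_results(jobs=rc.history[-1:])   # forget one recent msg_id
    rc.purge_results(targets=[0, 1])         # forget engines 0 and 1's entire history
    rc.purge_results('all')                  # scrub everything from the Hub's db
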
1654 1657 @spin_first
1655 1658 def hub_history(self):
1656 1659 """Get the Hub's history
1657 1660
1658 1661 Just like the Client, the Hub has a history, which is a list of msg_ids.
1659 1662 This will contain the history of all clients, and, depending on configuration,
1660 1663 may contain history across multiple cluster sessions.
1661 1664
1662 1665 Any msg_id returned here is a valid argument to `get_result`.
1663 1666
1664 1667 Returns
1665 1668 -------
1666 1669
1667 1670 msg_ids : list of strs
1668 1671 list of all msg_ids, ordered by task submission time.
1669 1672 """
1670 1673
1671 1674 self.session.send(self._query_socket, "history_request", content={})
1672 1675 idents, msg = self.session.recv(self._query_socket, 0)
1673 1676
1674 1677 if self.debug:
1675 1678 pprint(msg)
1676 1679 content = msg['content']
1677 1680 if content['status'] != 'ok':
1678 1681 raise self._unwrap_exception(content)
1679 1682 else:
1680 1683 return content['history']
1681 1684
1682 1685 @spin_first
1683 1686 def db_query(self, query, keys=None):
1684 1687 """Query the Hub's TaskRecord database
1685 1688
1686 1689 This will return a list of task record dicts that match `query`
1687 1690
1688 1691 Parameters
1689 1692 ----------
1690 1693
1691 1694 query : mongodb query dict
1692 1695 The search dict. See mongodb query docs for details.
1693 1696 keys : list of strs [optional]
1694 1697 The subset of keys to be returned. The default is to fetch everything but buffers.
1695 1698 'msg_id' will *always* be included.
1696 1699 """
1697 1700 if isinstance(keys, basestring):
1698 1701 keys = [keys]
1699 1702 content = dict(query=query, keys=keys)
1700 1703 self.session.send(self._query_socket, "db_request", content=content)
1701 1704 idents, msg = self.session.recv(self._query_socket, 0)
1702 1705 if self.debug:
1703 1706 pprint(msg)
1704 1707 content = msg['content']
1705 1708 if content['status'] != 'ok':
1706 1709 raise self._unwrap_exception(content)
1707 1710
1708 1711 records = content['records']
1709 1712
1710 1713 buffer_lens = content['buffer_lens']
1711 1714 result_buffer_lens = content['result_buffer_lens']
1712 1715 buffers = msg['buffers']
1713 1716 has_bufs = buffer_lens is not None
1714 1717 has_rbufs = result_buffer_lens is not None
1715 1718 for i,rec in enumerate(records):
1716 1719 # relink buffers
1717 1720 if has_bufs:
1718 1721 blen = buffer_lens[i]
1719 1722 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1720 1723 if has_rbufs:
1721 1724 blen = result_buffer_lens[i]
1722 1725 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1723 1726
1724 1727 return records
1725 1728
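A sketch of a TaskRecord query using mongodb-style operators, which the Hub's db backends support (assuming a connected `rc`):

    from datetime import datetime, timedelta
    yesterday = datetime.now() - timedelta(days=1)
    # tasks completed in the last day, fetching only timing keys:
    recs = rc.db_query({'completed': {'$gte': yesterday}},
                       keys=['msg_id', 'started', 'completed'])
    print len(recs)
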
1726 1729 __all__ = [ 'Client' ]