fix message built for engine dying during task...
MinRK
@@ -1,1718 +1,1715 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52 from IPython.zmq import serialize
53 53
54 54 from .asyncresult import AsyncResult, AsyncHubResult
55 55 from .view import DirectView, LoadBalancedView
56 56
57 57 if sys.version_info[0] >= 3:
58 58 # xrange is used in a couple 'isinstance' tests in py2
59 59 # should be just 'range' in 3k
60 60 xrange = range
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Decorators for Client methods
64 64 #--------------------------------------------------------------------------
65 65
66 66 @decorator
67 67 def spin_first(f, self, *args, **kwargs):
68 68 """Call spin() to sync state prior to calling the method."""
69 69 self.spin()
70 70 return f(self, *args, **kwargs)
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Classes
75 75 #--------------------------------------------------------------------------
76 76
77 77
78 78 class ExecuteReply(object):
79 79 """wrapper for finished Execute results"""
80 80 def __init__(self, msg_id, content, metadata):
81 81 self.msg_id = msg_id
82 82 self._content = content
83 83 self.execution_count = content['execution_count']
84 84 self.metadata = metadata
85 85
86 86 def __getitem__(self, key):
87 87 return self.metadata[key]
88 88
89 89 def __getattr__(self, key):
90 90 if key not in self.metadata:
91 91 raise AttributeError(key)
92 92 return self.metadata[key]
93 93
94 94 def __repr__(self):
95 95 pyout = self.metadata['pyout'] or {'data':{}}
96 96 text_out = pyout['data'].get('text/plain', '')
97 97 if len(text_out) > 32:
98 98 text_out = text_out[:29] + '...'
99 99
100 100 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101 101
102 102 def _repr_pretty_(self, p, cycle):
103 103 pyout = self.metadata['pyout'] or {'data':{}}
104 104 text_out = pyout['data'].get('text/plain', '')
105 105
106 106 if not text_out:
107 107 return
108 108
109 109 try:
110 110 ip = get_ipython()
111 111 except NameError:
112 112 colors = "NoColor"
113 113 else:
114 114 colors = ip.colors
115 115
116 116 if colors == "NoColor":
117 117 out = normal = ""
118 118 else:
119 119 out = TermColors.Red
120 120 normal = TermColors.Normal
121 121
122 122 if '\n' in text_out and not text_out.startswith('\n'):
123 123 # add newline for multiline reprs
124 124 text_out = '\n' + text_out
125 125
126 126 p.text(
127 127 out + u'Out[%i:%i]: ' % (
128 128 self.metadata['engine_id'], self.execution_count
129 129 ) + normal + text_out
130 130 )
131 131
132 132 def _repr_html_(self):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 return pyout['data'].get("text/html")
135 135
136 136 def _repr_latex_(self):
137 137 pyout = self.metadata['pyout'] or {'data':{}}
138 138 return pyout['data'].get("text/latex")
139 139
140 140 def _repr_json_(self):
141 141 pyout = self.metadata['pyout'] or {'data':{}}
142 142 return pyout['data'].get("application/json")
143 143
144 144 def _repr_javascript_(self):
145 145 pyout = self.metadata['pyout'] or {'data':{}}
146 146 return pyout['data'].get("application/javascript")
147 147
148 148 def _repr_png_(self):
149 149 pyout = self.metadata['pyout'] or {'data':{}}
150 150 return pyout['data'].get("image/png")
151 151
152 152 def _repr_jpeg_(self):
153 153 pyout = self.metadata['pyout'] or {'data':{}}
154 154 return pyout['data'].get("image/jpeg")
155 155
156 156 def _repr_svg_(self):
157 157 pyout = self.metadata['pyout'] or {'data':{}}
158 158 return pyout['data'].get("image/svg+xml")
159 159
160 160
161 161 class Metadata(dict):
162 162 """Subclass of dict for initializing metadata values.
163 163
164 164 Attribute access works on keys.
165 165
166 166 These objects have a strict set of keys - errors will raise if you try
167 167 to add new keys.
168 168 """
169 169 def __init__(self, *args, **kwargs):
170 170 dict.__init__(self)
171 171 md = {'msg_id' : None,
172 172 'submitted' : None,
173 173 'started' : None,
174 174 'completed' : None,
175 175 'received' : None,
176 176 'engine_uuid' : None,
177 177 'engine_id' : None,
178 178 'follow' : None,
179 179 'after' : None,
180 180 'status' : None,
181 181
182 182 'pyin' : None,
183 183 'pyout' : None,
184 184 'pyerr' : None,
185 185 'stdout' : '',
186 186 'stderr' : '',
187 187 'outputs' : [],
188 188 'data': {},
189 189 'outputs_ready' : False,
190 190 }
191 191 self.update(md)
192 192 self.update(dict(*args, **kwargs))
193 193
194 194 def __getattr__(self, key):
195 195 """getattr aliased to getitem"""
196 196 if key in self.iterkeys():
197 197 return self[key]
198 198 else:
199 199 raise AttributeError(key)
200 200
201 201 def __setattr__(self, key, value):
202 202         """setattr aliased to setitem, with strict key checking"""
203 203 if key in self.iterkeys():
204 204 self[key] = value
205 205 else:
206 206 raise AttributeError(key)
207 207
208 208 def __setitem__(self, key, value):
209 209 """strict static key enforcement"""
210 210 if key in self.iterkeys():
211 211 dict.__setitem__(self, key, value)
212 212 else:
213 213 raise KeyError(key)
214 214
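# Example (a sketch, not part of the original file): Metadata is a dict with a
# fixed schema - known keys support attribute access, unknown keys are
# rejected immediately:
#
#     md = Metadata(status='ok')
#     md.status        # -> 'ok' (attribute access aliases getitem)
#     md['stdout']     # -> ''   (defaults are pre-populated)
#     md['bogus'] = 1  # raises KeyError: the key set is strict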
215 215
216 216 class Client(HasTraits):
217 217 """A semi-synchronous client to the IPython ZMQ cluster
218 218
219 219 Parameters
220 220 ----------
221 221
222 222 url_file : str/unicode; path to ipcontroller-client.json
223 223 This JSON file should contain all the information needed to connect to a cluster,
224 224         and is likely the only argument needed, since it provides the connection
225 225         information for the Hub's registration; with a json connector file,
226 226         no further configuration is typically necessary.
227 227 [Default: use profile]
228 228 profile : bytes
229 229 The name of the Cluster profile to be used to find connector information.
230 230 If run from an IPython application, the default profile will be the same
231 231 as the running application, otherwise it will be 'default'.
232 232 context : zmq.Context
233 233 Pass an existing zmq.Context instance, otherwise the client will create its own.
234 234 debug : bool
235 235 flag for lots of message printing for debug purposes
236 236 timeout : int/float
237 237 time (in seconds) to wait for connection replies from the Hub
238 238 [Default: 10]
239 239
240 240 #-------------- session related args ----------------
241 241
242 242 config : Config object
243 243 If specified, this will be relayed to the Session for configuration
244 244 username : str
245 245 set username for the session object
246 246
247 247 #-------------- ssh related args ----------------
248 248 # These are args for configuring the ssh tunnel to be used
249 249 # credentials are used to forward connections over ssh to the Controller
250 250 # Note that the ip given in `addr` needs to be relative to sshserver
251 251 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
252 252 # and set sshserver as the same machine the Controller is on. However,
253 253 # the only requirement is that sshserver is able to see the Controller
254 254 # (i.e. is within the same trusted network).
255 255
256 256 sshserver : str
257 257 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
258 258 If keyfile or password is specified, and this is not, it will default to
259 259 the ip given in addr.
260 260 sshkey : str; path to ssh private key file
261 261 This specifies a key to be used in ssh login, default None.
262 262 Regular default ssh keys will be used without specifying this argument.
263 263 password : str
264 264 Your ssh password to sshserver. Note that if this is left None,
265 265 you will be prompted for it if passwordless key based login is unavailable.
266 266 paramiko : bool
267 267 flag for whether to use paramiko instead of shell ssh for tunneling.
268 268 [default: True on win32, False else]
269 269
270 270
271 271 Attributes
272 272 ----------
273 273
274 274 ids : list of int engine IDs
275 275 requesting the ids attribute always synchronizes
276 276 the registration state. To request ids without synchronization,
277 277 use semi-private _ids attributes.
278 278
279 279 history : list of msg_ids
280 280 a list of msg_ids, keeping track of all the execution
281 281 messages you have submitted in order.
282 282
283 283 outstanding : set of msg_ids
284 284 a set of msg_ids that have been submitted, but whose
285 285 results have not yet been received.
286 286
287 287 results : dict
288 288 a dict of all our results, keyed by msg_id
289 289
290 290 block : bool
291 291 determines default behavior when block not specified
292 292 in execution methods
293 293
294 294 Methods
295 295 -------
296 296
297 297 spin
298 298 flushes incoming results and registration state changes
299 299             control methods spin as well, and requesting `ids` also ensures up-to-date state
300 300
301 301 wait
302 302 wait on one or more msg_ids
303 303
304 304 execution methods
305 305 apply
306 306 legacy: execute, run
307 307
308 308 data movement
309 309 push, pull, scatter, gather
310 310
311 311 query methods
312 312 queue_status, get_result, purge, result_status
313 313
314 314 control methods
315 315 abort, shutdown
316 316
317 317 """
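    # Example (a sketch, assuming a running ipcontroller; 'user@gateway' is a
    # placeholder host): construction usually just reads the connection file,
    # optionally tunneling over ssh:
    #
    #     rc = Client()                    # default profile
    #     rc = Client(profile='mycluster') # named profile
    #     rc = Client('/path/to/ipcontroller-client.json',
    #                 sshserver='user@gateway')
    #     rc.ids                           # engine ids, synced with the Hub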
318 318
319 319
320 320 block = Bool(False)
321 321 outstanding = Set()
322 322 results = Instance('collections.defaultdict', (dict,))
323 323 metadata = Instance('collections.defaultdict', (Metadata,))
324 324 history = List()
325 325 debug = Bool(False)
326 326 _spin_thread = Any()
327 327 _stop_spinning = Any()
328 328
329 329 profile=Unicode()
330 330 def _profile_default(self):
331 331 if BaseIPythonApplication.initialized():
332 332 # an IPython app *might* be running, try to get its profile
333 333 try:
334 334 return BaseIPythonApplication.instance().profile
335 335 except (AttributeError, MultipleInstanceError):
336 336 # could be a *different* subclass of config.Application,
337 337 # which would raise one of these two errors.
338 338 return u'default'
339 339 else:
340 340 return u'default'
341 341
342 342
343 343 _outstanding_dict = Instance('collections.defaultdict', (set,))
344 344 _ids = List()
345 345 _connected=Bool(False)
346 346 _ssh=Bool(False)
347 347 _context = Instance('zmq.Context')
348 348 _config = Dict()
349 349 _engines=Instance(util.ReverseDict, (), {})
350 350 # _hub_socket=Instance('zmq.Socket')
351 351 _query_socket=Instance('zmq.Socket')
352 352 _control_socket=Instance('zmq.Socket')
353 353 _iopub_socket=Instance('zmq.Socket')
354 354 _notification_socket=Instance('zmq.Socket')
355 355 _mux_socket=Instance('zmq.Socket')
356 356 _task_socket=Instance('zmq.Socket')
357 357 _task_scheme=Unicode()
358 358 _closed = False
359 359 _ignored_control_replies=Integer(0)
360 360 _ignored_hub_replies=Integer(0)
361 361
362 362 def __new__(self, *args, **kw):
363 363 # don't raise on positional args
364 364 return HasTraits.__new__(self, **kw)
365 365
366 366 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
367 367 context=None, debug=False,
368 368 sshserver=None, sshkey=None, password=None, paramiko=None,
369 369 timeout=10, **extra_args
370 370 ):
371 371 if profile:
372 372 super(Client, self).__init__(debug=debug, profile=profile)
373 373 else:
374 374 super(Client, self).__init__(debug=debug)
375 375 if context is None:
376 376 context = zmq.Context.instance()
377 377 self._context = context
378 378 self._stop_spinning = Event()
379 379
380 380 if 'url_or_file' in extra_args:
381 381 url_file = extra_args['url_or_file']
382 382 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
383 383
384 384 if url_file and util.is_url(url_file):
385 385 raise ValueError("single urls cannot be specified, url-files must be used.")
386 386
387 387 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
388 388
389 389 if self._cd is not None:
390 390 if url_file is None:
391 391 url_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
392 392 if url_file is None:
393 393 raise ValueError(
394 394 "I can't find enough information to connect to a hub!"
395 395 " Please specify at least one of url_file or profile."
396 396 )
397 397
398 398 with open(url_file) as f:
399 399 cfg = json.load(f)
400 400
401 401 self._task_scheme = cfg['task_scheme']
402 402
403 403 # sync defaults from args, json:
404 404 if sshserver:
405 405 cfg['ssh'] = sshserver
406 406
407 407 location = cfg.setdefault('location', None)
408 408
409 409 proto,addr = cfg['interface'].split('://')
410 410 addr = util.disambiguate_ip_address(addr)
411 411 cfg['interface'] = "%s://%s" % (proto, addr)
412 412
413 413 # turn interface,port into full urls:
414 414 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
415 415 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
416 416
417 417 url = cfg['registration']
418 418
419 419 if location is not None and addr == '127.0.0.1':
420 420 # location specified, and connection is expected to be local
421 421 if location not in LOCAL_IPS and not sshserver:
422 422 # load ssh from JSON *only* if the controller is not on
423 423 # this machine
424 424 sshserver=cfg['ssh']
425 425 if location not in LOCAL_IPS and not sshserver:
426 426 # warn if no ssh specified, but SSH is probably needed
427 427 # This is only a warning, because the most likely cause
428 428 # is a local Controller on a laptop whose IP is dynamic
429 429 warnings.warn("""
430 430 Controller appears to be listening on localhost, but not on this machine.
431 431 If this is true, you should specify Client(...,sshserver='you@%s')
432 432 or instruct your controller to listen on an external IP."""%location,
433 433 RuntimeWarning)
434 434 elif not sshserver:
435 435 # otherwise sync with cfg
436 436 sshserver = cfg['ssh']
437 437
438 438 self._config = cfg
439 439
440 440 self._ssh = bool(sshserver or sshkey or password)
441 441 if self._ssh and sshserver is None:
442 442 # default to ssh via localhost
443 443 sshserver = addr
444 444 if self._ssh and password is None:
445 445 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
446 446 password=False
447 447 else:
448 448 password = getpass("SSH Password for %s: "%sshserver)
449 449 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
450 450
451 451 # configure and construct the session
452 452 extra_args['packer'] = cfg['pack']
453 453 extra_args['unpacker'] = cfg['unpack']
454 454 extra_args['key'] = cast_bytes(cfg['exec_key'])
455 455
456 456 self.session = Session(**extra_args)
457 457
458 458 self._query_socket = self._context.socket(zmq.DEALER)
459 459
460 460 if self._ssh:
461 461 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
462 462 else:
463 463 self._query_socket.connect(cfg['registration'])
464 464
465 465 self.session.debug = self.debug
466 466
467 467 self._notification_handlers = {'registration_notification' : self._register_engine,
468 468 'unregistration_notification' : self._unregister_engine,
469 469 'shutdown_notification' : lambda msg: self.close(),
470 470 }
471 471 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
472 472 'apply_reply' : self._handle_apply_reply}
473 473 self._connect(sshserver, ssh_kwargs, timeout)
474 474
475 475 # last step: setup magics, if we are in IPython:
476 476
477 477 try:
478 478 ip = get_ipython()
479 479 except NameError:
480 480 return
481 481 else:
482 482 if 'px' not in ip.magics_manager.magics:
483 483 # in IPython but we are the first Client.
484 484 # activate a default view for parallel magics.
485 485 self.activate()
486 486
487 487 def __del__(self):
488 488 """cleanup sockets, but _not_ context."""
489 489 self.close()
490 490
491 491 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
492 492 if ipython_dir is None:
493 493 ipython_dir = get_ipython_dir()
494 494 if profile_dir is not None:
495 495 try:
496 496 self._cd = ProfileDir.find_profile_dir(profile_dir)
497 497 return
498 498 except ProfileDirError:
499 499 pass
500 500 elif profile is not None:
501 501 try:
502 502 self._cd = ProfileDir.find_profile_dir_by_name(
503 503 ipython_dir, profile)
504 504 return
505 505 except ProfileDirError:
506 506 pass
507 507 self._cd = None
508 508
509 509 def _update_engines(self, engines):
510 510 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
511 511 for k,v in engines.iteritems():
512 512 eid = int(k)
513 513 if eid not in self._engines:
514 514 self._ids.append(eid)
515 515 self._engines[eid] = v
516 516 self._ids = sorted(self._ids)
517 517 if sorted(self._engines.keys()) != range(len(self._engines)) and \
518 518 self._task_scheme == 'pure' and self._task_socket:
519 519 self._stop_scheduling_tasks()
520 520
521 521 def _stop_scheduling_tasks(self):
522 522 """Stop scheduling tasks because an engine has been unregistered
523 523 from a pure ZMQ scheduler.
524 524 """
525 525 self._task_socket.close()
526 526 self._task_socket = None
527 527 msg = "An engine has been unregistered, and we are using pure " +\
528 528 "ZMQ task scheduling. Task farming will be disabled."
529 529 if self.outstanding:
530 530 msg += " If you were running tasks when this happened, " +\
531 531 "some `outstanding` msg_ids may never resolve."
532 532 warnings.warn(msg, RuntimeWarning)
533 533
534 534 def _build_targets(self, targets):
535 535 """Turn valid target IDs or 'all' into two lists:
536 536 (int_ids, uuids).
537 537 """
538 538 if not self._ids:
539 539 # flush notification socket if no engines yet, just in case
540 540 if not self.ids:
541 541 raise error.NoEnginesRegistered("Can't build targets without any engines")
542 542
543 543 if targets is None:
544 544 targets = self._ids
545 545 elif isinstance(targets, basestring):
546 546 if targets.lower() == 'all':
547 547 targets = self._ids
548 548 else:
549 549 raise TypeError("%r not valid str target, must be 'all'"%(targets))
550 550 elif isinstance(targets, int):
551 551 if targets < 0:
552 552 targets = self.ids[targets]
553 553 if targets not in self._ids:
554 554 raise IndexError("No such engine: %i"%targets)
555 555 targets = [targets]
556 556
557 557 if isinstance(targets, slice):
558 558 indices = range(len(self._ids))[targets]
559 559 ids = self.ids
560 560 targets = [ ids[i] for i in indices ]
561 561
562 562 if not isinstance(targets, (tuple, list, xrange)):
563 563 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
564 564
565 565 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
566 566
567 567 def _connect(self, sshserver, ssh_kwargs, timeout):
568 568 """setup all our socket connections to the cluster. This is called from
569 569 __init__."""
570 570
571 571 # Maybe allow reconnecting?
572 572 if self._connected:
573 573 return
574 574 self._connected=True
575 575
576 576 def connect_socket(s, url):
577 577 # url = util.disambiguate_url(url, self._config['location'])
578 578 if self._ssh:
579 579 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
580 580 else:
581 581 return s.connect(url)
582 582
583 583 self.session.send(self._query_socket, 'connection_request')
584 584 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
585 585 poller = zmq.Poller()
586 586 poller.register(self._query_socket, zmq.POLLIN)
587 587 # poll expects milliseconds, timeout is seconds
588 588 evts = poller.poll(timeout*1000)
589 589 if not evts:
590 590 raise error.TimeoutError("Hub connection request timed out")
591 591 idents,msg = self.session.recv(self._query_socket,mode=0)
592 592 if self.debug:
593 593 pprint(msg)
594 594 content = msg['content']
595 595 # self._config['registration'] = dict(content)
596 596 cfg = self._config
597 597 if content['status'] == 'ok':
598 598 self._mux_socket = self._context.socket(zmq.DEALER)
599 599 connect_socket(self._mux_socket, cfg['mux'])
600 600
601 601 self._task_socket = self._context.socket(zmq.DEALER)
602 602 connect_socket(self._task_socket, cfg['task'])
603 603
604 604 self._notification_socket = self._context.socket(zmq.SUB)
605 605 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
606 606 connect_socket(self._notification_socket, cfg['notification'])
607 607
608 608 self._control_socket = self._context.socket(zmq.DEALER)
609 609 connect_socket(self._control_socket, cfg['control'])
610 610
611 611 self._iopub_socket = self._context.socket(zmq.SUB)
612 612 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
613 613 connect_socket(self._iopub_socket, cfg['iopub'])
614 614
615 615 self._update_engines(dict(content['engines']))
616 616 else:
617 617 self._connected = False
618 618 raise Exception("Failed to connect!")
619 619
620 620 #--------------------------------------------------------------------------
621 621 # handlers and callbacks for incoming messages
622 622 #--------------------------------------------------------------------------
623 623
624 624 def _unwrap_exception(self, content):
625 625 """unwrap exception, and remap engine_id to int."""
626 626 e = error.unwrap_exception(content)
627 627 # print e.traceback
628 628 if e.engine_info:
629 629 e_uuid = e.engine_info['engine_uuid']
630 630 eid = self._engines[e_uuid]
631 631 e.engine_info['engine_id'] = eid
632 632 return e
633 633
634 634 def _extract_metadata(self, msg):
635 635 header = msg['header']
636 636 parent = msg['parent_header']
637 637 msg_meta = msg['metadata']
638 638 content = msg['content']
639 639 md = {'msg_id' : parent['msg_id'],
640 640 'received' : datetime.now(),
641 641 'engine_uuid' : msg_meta.get('engine', None),
642 642 'follow' : msg_meta.get('follow', []),
643 643 'after' : msg_meta.get('after', []),
644 644 'status' : content['status'],
645 645 }
646 646
647 647 if md['engine_uuid'] is not None:
648 648 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
649 649
650 650 if 'date' in parent:
651 651 md['submitted'] = parent['date']
652 652 if 'started' in msg_meta:
653 653 md['started'] = msg_meta['started']
654 654 if 'date' in header:
655 655 md['completed'] = header['date']
656 656 return md
657 657
658 658 def _register_engine(self, msg):
659 659 """Register a new engine, and update our connection info."""
660 660 content = msg['content']
661 661 eid = content['id']
662 662 d = {eid : content['uuid']}
663 663 self._update_engines(d)
664 664
665 665 def _unregister_engine(self, msg):
666 666 """Unregister an engine that has died."""
667 667 content = msg['content']
668 668 eid = int(content['id'])
669 669 if eid in self._ids:
670 670 self._ids.remove(eid)
671 671 uuid = self._engines.pop(eid)
672 672
673 673 self._handle_stranded_msgs(eid, uuid)
674 674
675 675 if self._task_socket and self._task_scheme == 'pure':
676 676 self._stop_scheduling_tasks()
677 677
678 678 def _handle_stranded_msgs(self, eid, uuid):
679 679 """Handle messages known to be on an engine when the engine unregisters.
680 680
681 681 It is possible that this will fire prematurely - that is, an engine will
682 682 go down after completing a result, and the client will be notified
683 683 of the unregistration and later receive the successful result.
684 684 """
685 685
686 686 outstanding = self._outstanding_dict[uuid]
687 687
688 688 for msg_id in list(outstanding):
689 689 if msg_id in self.results:
690 690                 # we already have the result; nothing was stranded
691 691 continue
692 692 try:
693 693 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
694 694 except:
695 695 content = error.wrap_exception()
696 696 # build a fake message:
697 parent = {}
698 header = {}
699 parent['msg_id'] = msg_id
700 header['engine'] = uuid
701 header['date'] = datetime.now()
702 msg = dict(parent_header=parent, header=header, content=content)
697 msg = self.session.msg('apply_reply', content=content)
698 msg['parent_header']['msg_id'] = msg_id
699 msg['metadata']['engine'] = uuid
703 700 self._handle_apply_reply(msg)
704 701
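    # Note on the fix above: the stranded-task reply is now built with
    # session.msg('apply_reply', ...), so it carries a complete, well-formed
    # header/parent_header/metadata structure; only the two fields the reply
    # handler relies on are then filled in by hand:
    #
    #     msg['parent_header']['msg_id']  # the stranded task's msg_id
    #     msg['metadata']['engine']       # uuid of the engine that died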
705 702 def _handle_execute_reply(self, msg):
706 703 """Save the reply to an execute_request into our results.
707 704
708 705 execute messages are never actually used. apply is used instead.
709 706 """
710 707
711 708 parent = msg['parent_header']
712 709 msg_id = parent['msg_id']
713 710 if msg_id not in self.outstanding:
714 711 if msg_id in self.history:
715 712 print ("got stale result: %s"%msg_id)
716 713 else:
717 714 print ("got unknown result: %s"%msg_id)
718 715 else:
719 716 self.outstanding.remove(msg_id)
720 717
721 718 content = msg['content']
722 719 header = msg['header']
723 720
724 721 # construct metadata:
725 722 md = self.metadata[msg_id]
726 723 md.update(self._extract_metadata(msg))
727 724 # is this redundant?
728 725 self.metadata[msg_id] = md
729 726
730 727 e_outstanding = self._outstanding_dict[md['engine_uuid']]
731 728 if msg_id in e_outstanding:
732 729 e_outstanding.remove(msg_id)
733 730
734 731 # construct result:
735 732 if content['status'] == 'ok':
736 733 self.results[msg_id] = ExecuteReply(msg_id, content, md)
737 734 elif content['status'] == 'aborted':
738 735 self.results[msg_id] = error.TaskAborted(msg_id)
739 736 elif content['status'] == 'resubmitted':
740 737 # TODO: handle resubmission
741 738 pass
742 739 else:
743 740 self.results[msg_id] = self._unwrap_exception(content)
744 741
745 742 def _handle_apply_reply(self, msg):
746 743 """Save the reply to an apply_request into our results."""
747 744 parent = msg['parent_header']
748 745 msg_id = parent['msg_id']
749 746 if msg_id not in self.outstanding:
750 747 if msg_id in self.history:
751 748 print ("got stale result: %s"%msg_id)
752 749 print self.results[msg_id]
753 750 print msg
754 751 else:
755 752 print ("got unknown result: %s"%msg_id)
756 753 else:
757 754 self.outstanding.remove(msg_id)
758 755 content = msg['content']
759 756 header = msg['header']
760 757
761 758 # construct metadata:
762 759 md = self.metadata[msg_id]
763 760 md.update(self._extract_metadata(msg))
764 761 # is this redundant?
765 762 self.metadata[msg_id] = md
766 763
767 764 e_outstanding = self._outstanding_dict[md['engine_uuid']]
768 765 if msg_id in e_outstanding:
769 766 e_outstanding.remove(msg_id)
770 767
771 768 # construct result:
772 769 if content['status'] == 'ok':
773 770 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
774 771 elif content['status'] == 'aborted':
775 772 self.results[msg_id] = error.TaskAborted(msg_id)
776 773 elif content['status'] == 'resubmitted':
777 774 # TODO: handle resubmission
778 775 pass
779 776 else:
780 777 self.results[msg_id] = self._unwrap_exception(content)
781 778
782 779 def _flush_notifications(self):
783 780 """Flush notifications of engine registrations waiting
784 781 in ZMQ queue."""
785 782 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
786 783 while msg is not None:
787 784 if self.debug:
788 785 pprint(msg)
789 786 msg_type = msg['header']['msg_type']
790 787 handler = self._notification_handlers.get(msg_type, None)
791 788 if handler is None:
792 789                 raise Exception("Unhandled message type: %s"%msg_type)
793 790 else:
794 791 handler(msg)
795 792 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
796 793
797 794 def _flush_results(self, sock):
798 795 """Flush task or queue results waiting in ZMQ queue."""
799 796 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
800 797 while msg is not None:
801 798 if self.debug:
802 799 pprint(msg)
803 800 msg_type = msg['header']['msg_type']
804 801 handler = self._queue_handlers.get(msg_type, None)
805 802 if handler is None:
806 803                 raise Exception("Unhandled message type: %s"%msg_type)
807 804 else:
808 805 handler(msg)
809 806 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
810 807
811 808 def _flush_control(self, sock):
812 809 """Flush replies from the control channel waiting
813 810 in the ZMQ queue.
814 811
815 812 Currently: ignore them."""
816 813 if self._ignored_control_replies <= 0:
817 814 return
818 815 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
819 816 while msg is not None:
820 817 self._ignored_control_replies -= 1
821 818 if self.debug:
822 819 pprint(msg)
823 820 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
824 821
825 822 def _flush_ignored_control(self):
826 823 """flush ignored control replies"""
827 824 while self._ignored_control_replies > 0:
828 825 self.session.recv(self._control_socket)
829 826 self._ignored_control_replies -= 1
830 827
831 828 def _flush_ignored_hub_replies(self):
832 829 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
833 830 while msg is not None:
834 831 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
835 832
836 833 def _flush_iopub(self, sock):
837 834 """Flush replies from the iopub channel waiting
838 835 in the ZMQ queue.
839 836 """
840 837 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 838 while msg is not None:
842 839 if self.debug:
843 840 pprint(msg)
844 841 parent = msg['parent_header']
845 842 # ignore IOPub messages with no parent.
846 843 # Caused by print statements or warnings from before the first execution.
847 844 if not parent:
848 845                 # recv the next message before continuing, else we loop on the same msg
                        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
                        continue
849 846 msg_id = parent['msg_id']
850 847 content = msg['content']
851 848 header = msg['header']
852 849 msg_type = msg['header']['msg_type']
853 850
854 851 # init metadata:
855 852 md = self.metadata[msg_id]
856 853
857 854 if msg_type == 'stream':
858 855 name = content['name']
859 856 s = md[name] or ''
860 857 md[name] = s + content['data']
861 858 elif msg_type == 'pyerr':
862 859 md.update({'pyerr' : self._unwrap_exception(content)})
863 860 elif msg_type == 'pyin':
864 861 md.update({'pyin' : content['code']})
865 862 elif msg_type == 'display_data':
866 863 md['outputs'].append(content)
867 864 elif msg_type == 'pyout':
868 865 md['pyout'] = content
869 866 elif msg_type == 'data_message':
870 867 data, remainder = serialize.unserialize_object(msg['buffers'])
871 868 md['data'].update(data)
872 869 elif msg_type == 'status':
873 870 # idle message comes after all outputs
874 871 if content['execution_state'] == 'idle':
875 872 md['outputs_ready'] = True
876 873 else:
877 874                 # unhandled msg_type
878 875 pass
879 876
880 877         # redundant?
881 878 self.metadata[msg_id] = md
882 879
883 880 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
884 881
885 882 #--------------------------------------------------------------------------
886 883 # len, getitem
887 884 #--------------------------------------------------------------------------
888 885
889 886 def __len__(self):
890 887 """len(client) returns # of engines."""
891 888 return len(self.ids)
892 889
893 890 def __getitem__(self, key):
894 891 """index access returns DirectView multiplexer objects
895 892
896 893 Must be int, slice, or list/tuple/xrange of ints"""
897 894 if not isinstance(key, (int, slice, tuple, list, xrange)):
898 895 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
899 896 else:
900 897 return self.direct_view(key)
901 898
902 899 #--------------------------------------------------------------------------
903 900 # Begin public methods
904 901 #--------------------------------------------------------------------------
905 902
906 903 @property
907 904 def ids(self):
908 905 """Always up-to-date ids property."""
909 906 self._flush_notifications()
910 907 # always copy:
911 908 return list(self._ids)
912 909
913 910 def activate(self, targets='all', suffix=''):
914 911 """Create a DirectView and register it with IPython magics
915 912
916 913 Defines the magics `%px, %autopx, %pxresult, %%px`
917 914
918 915 Parameters
919 916 ----------
920 917
921 918 targets: int, list of ints, or 'all'
922 919 The engines on which the view's magics will run
923 920 suffix: str [default: '']
924 921 The suffix, if any, for the magics. This allows you to have
925 922 multiple views associated with parallel magics at the same time.
926 923
927 924 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
928 925 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
929 926 on engine 0.
930 927 """
931 928 view = self.direct_view(targets)
932 929 view.block = True
933 930 view.activate(suffix)
934 931 return view
935 932
936 933 def close(self):
937 934 if self._closed:
938 935 return
939 936 self.stop_spin_thread()
940 937 snames = filter(lambda n: n.endswith('socket'), dir(self))
941 938 for socket in map(lambda name: getattr(self, name), snames):
942 939 if isinstance(socket, zmq.Socket) and not socket.closed:
943 940 socket.close()
944 941 self._closed = True
945 942
946 943 def _spin_every(self, interval=1):
947 944 """target func for use in spin_thread"""
948 945 while True:
949 946 if self._stop_spinning.is_set():
950 947 return
951 948 time.sleep(interval)
952 949 self.spin()
953 950
954 951 def spin_thread(self, interval=1):
955 952 """call Client.spin() in a background thread on some regular interval
956 953
957 954 This helps ensure that messages don't pile up too much in the zmq queue
958 955 while you are working on other things, or just leaving an idle terminal.
959 956
960 957 It also helps limit potential padding of the `received` timestamp
961 958 on AsyncResult objects, used for timings.
962 959
963 960 Parameters
964 961 ----------
965 962
966 963 interval : float, optional
967 964 The interval on which to spin the client in the background thread
968 965 (simply passed to time.sleep).
969 966
970 967 Notes
971 968 -----
972 969
973 970 For precision timing, you may want to use this method to put a bound
974 971 on the jitter (in seconds) in `received` timestamps used
975 972 in AsyncResult.wall_time.
976 973
977 974 """
978 975 if self._spin_thread is not None:
979 976 self.stop_spin_thread()
980 977 self._stop_spinning.clear()
981 978 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
982 979 self._spin_thread.daemon = True
983 980 self._spin_thread.start()
984 981
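    # Example (sketch): keep results flowing in while doing other work,
    # then shut the background thread down:
    #
    #     rc.spin_thread(interval=0.1)  # spin every 100ms
    #     # ... submit work, compute locally ...
    #     rc.stop_spin_thread()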
985 982 def stop_spin_thread(self):
986 983 """stop background spin_thread, if any"""
987 984 if self._spin_thread is not None:
988 985 self._stop_spinning.set()
989 986 self._spin_thread.join()
990 987 self._spin_thread = None
991 988
992 989 def spin(self):
993 990 """Flush any registration notifications and execution results
994 991 waiting in the ZMQ queue.
995 992 """
996 993 if self._notification_socket:
997 994 self._flush_notifications()
998 995 if self._iopub_socket:
999 996 self._flush_iopub(self._iopub_socket)
1000 997 if self._mux_socket:
1001 998 self._flush_results(self._mux_socket)
1002 999 if self._task_socket:
1003 1000 self._flush_results(self._task_socket)
1004 1001 if self._control_socket:
1005 1002 self._flush_control(self._control_socket)
1006 1003 if self._query_socket:
1007 1004 self._flush_ignored_hub_replies()
1008 1005
1009 1006 def wait(self, jobs=None, timeout=-1):
1010 1007 """waits on one or more `jobs`, for up to `timeout` seconds.
1011 1008
1012 1009 Parameters
1013 1010 ----------
1014 1011
1015 1012 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1016 1013 ints are indices to self.history
1017 1014 strs are msg_ids
1018 1015 default: wait on all outstanding messages
1019 1016 timeout : float
1020 1017 a time in seconds, after which to give up.
1021 1018 default is -1, which means no timeout
1022 1019
1023 1020 Returns
1024 1021 -------
1025 1022
1026 1023 True : when all msg_ids are done
1027 1024 False : timeout reached, some msg_ids still outstanding
1028 1025 """
1029 1026 tic = time.time()
1030 1027 if jobs is None:
1031 1028 theids = self.outstanding
1032 1029 else:
1033 1030 if isinstance(jobs, (int, basestring, AsyncResult)):
1034 1031 jobs = [jobs]
1035 1032 theids = set()
1036 1033 for job in jobs:
1037 1034 if isinstance(job, int):
1038 1035 # index access
1039 1036 job = self.history[job]
1040 1037 elif isinstance(job, AsyncResult):
1041 1038 map(theids.add, job.msg_ids)
1042 1039 continue
1043 1040 theids.add(job)
1044 1041 if not theids.intersection(self.outstanding):
1045 1042 return True
1046 1043 self.spin()
1047 1044 while theids.intersection(self.outstanding):
1048 1045 if timeout >= 0 and ( time.time()-tic ) > timeout:
1049 1046 break
1050 1047 time.sleep(1e-3)
1051 1048 self.spin()
1052 1049 return len(theids.intersection(self.outstanding)) == 0
1053 1050
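    # Example (sketch; 'view' is assumed to come from rc[:]): block for up
    # to five seconds on one batch of tasks:
    #
    #     ar = view.apply_async(lambda x: x*2, 21)
    #     if rc.wait(ar, timeout=5):
    #         print ar.get()
    #     else:
    #         print "still outstanding:", rc.outstanding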
1054 1051 #--------------------------------------------------------------------------
1055 1052 # Control methods
1056 1053 #--------------------------------------------------------------------------
1057 1054
1058 1055 @spin_first
1059 1056 def clear(self, targets=None, block=None):
1060 1057 """Clear the namespace in target(s)."""
1061 1058 block = self.block if block is None else block
1062 1059 targets = self._build_targets(targets)[0]
1063 1060 for t in targets:
1064 1061 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1065 1062 error = False
1066 1063 if block:
1067 1064 self._flush_ignored_control()
1068 1065 for i in range(len(targets)):
1069 1066 idents,msg = self.session.recv(self._control_socket,0)
1070 1067 if self.debug:
1071 1068 pprint(msg)
1072 1069 if msg['content']['status'] != 'ok':
1073 1070 error = self._unwrap_exception(msg['content'])
1074 1071 else:
1075 1072 self._ignored_control_replies += len(targets)
1076 1073 if error:
1077 1074 raise error
1078 1075
1079 1076
1080 1077 @spin_first
1081 1078 def abort(self, jobs=None, targets=None, block=None):
1082 1079 """Abort specific jobs from the execution queues of target(s).
1083 1080
1084 1081 This is a mechanism to prevent jobs that have already been submitted
1085 1082 from executing.
1086 1083
1087 1084 Parameters
1088 1085 ----------
1089 1086
1090 1087 jobs : msg_id, list of msg_ids, or AsyncResult
1091 1088 The jobs to be aborted
1092 1089
1093 1090 If unspecified/None: abort all outstanding jobs.
1094 1091
1095 1092 """
1096 1093 block = self.block if block is None else block
1097 1094 jobs = jobs if jobs is not None else list(self.outstanding)
1098 1095 targets = self._build_targets(targets)[0]
1099 1096
1100 1097 msg_ids = []
1101 1098 if isinstance(jobs, (basestring,AsyncResult)):
1102 1099 jobs = [jobs]
1103 1100 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1104 1101 if bad_ids:
1105 1102 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1106 1103 for j in jobs:
1107 1104 if isinstance(j, AsyncResult):
1108 1105 msg_ids.extend(j.msg_ids)
1109 1106 else:
1110 1107 msg_ids.append(j)
1111 1108 content = dict(msg_ids=msg_ids)
1112 1109 for t in targets:
1113 1110 self.session.send(self._control_socket, 'abort_request',
1114 1111 content=content, ident=t)
1115 1112 error = False
1116 1113 if block:
1117 1114 self._flush_ignored_control()
1118 1115 for i in range(len(targets)):
1119 1116 idents,msg = self.session.recv(self._control_socket,0)
1120 1117 if self.debug:
1121 1118 pprint(msg)
1122 1119 if msg['content']['status'] != 'ok':
1123 1120 error = self._unwrap_exception(msg['content'])
1124 1121 else:
1125 1122 self._ignored_control_replies += len(targets)
1126 1123 if error:
1127 1124 raise error
1128 1125
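    # Example (sketch; 'ar' is an AsyncResult from a prior submission):
    #
    #     rc.abort(ar)  # abort just this batch, if not yet started
    #     rc.abort()    # abort everything still outstanding on the targets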
1129 1126 @spin_first
1130 1127 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1131 1128 """Terminates one or more engine processes, optionally including the hub.
1132 1129
1133 1130 Parameters
1134 1131 ----------
1135 1132
1136 1133 targets: list of ints or 'all' [default: all]
1137 1134 Which engines to shutdown.
1138 1135 hub: bool [default: False]
1139 1136 Whether to include the Hub. hub=True implies targets='all'.
1140 1137 block: bool [default: self.block]
1141 1138 Whether to wait for clean shutdown replies or not.
1142 1139 restart: bool [default: False]
1143 1140 NOT IMPLEMENTED
1144 1141 whether to restart engines after shutting them down.
1145 1142 """
1146 1143
1147 1144 if restart:
1148 1145 raise NotImplementedError("Engine restart is not yet implemented")
1149 1146
1150 1147 block = self.block if block is None else block
1151 1148 if hub:
1152 1149 targets = 'all'
1153 1150 targets = self._build_targets(targets)[0]
1154 1151 for t in targets:
1155 1152 self.session.send(self._control_socket, 'shutdown_request',
1156 1153 content={'restart':restart},ident=t)
1157 1154 error = False
1158 1155 if block or hub:
1159 1156 self._flush_ignored_control()
1160 1157 for i in range(len(targets)):
1161 1158 idents,msg = self.session.recv(self._control_socket, 0)
1162 1159 if self.debug:
1163 1160 pprint(msg)
1164 1161 if msg['content']['status'] != 'ok':
1165 1162 error = self._unwrap_exception(msg['content'])
1166 1163 else:
1167 1164 self._ignored_control_replies += len(targets)
1168 1165
1169 1166 if hub:
1170 1167 time.sleep(0.25)
1171 1168 self.session.send(self._query_socket, 'shutdown_request')
1172 1169 idents,msg = self.session.recv(self._query_socket, 0)
1173 1170 if self.debug:
1174 1171 pprint(msg)
1175 1172 if msg['content']['status'] != 'ok':
1176 1173 error = self._unwrap_exception(msg['content'])
1177 1174
1178 1175 if error:
1179 1176 raise error
1180 1177
1181 1178 #--------------------------------------------------------------------------
1182 1179 # Execution related methods
1183 1180 #--------------------------------------------------------------------------
1184 1181
1185 1182 def _maybe_raise(self, result):
1186 1183 """wrapper for maybe raising an exception if apply failed."""
1187 1184 if isinstance(result, error.RemoteError):
1188 1185 raise result
1189 1186
1190 1187 return result
1191 1188
1192 1189 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1193 1190 ident=None):
1194 1191 """construct and send an apply message via a socket.
1195 1192
1196 1193 This is the principal method with which all engine execution is performed by views.
1197 1194 """
1198 1195
1199 1196 if self._closed:
1200 1197 raise RuntimeError("Client cannot be used after its sockets have been closed")
1201 1198
1202 1199 # defaults:
1203 1200 args = args if args is not None else []
1204 1201 kwargs = kwargs if kwargs is not None else {}
1205 1202 metadata = metadata if metadata is not None else {}
1206 1203
1207 1204 # validate arguments
1208 1205 if not callable(f) and not isinstance(f, Reference):
1209 1206 raise TypeError("f must be callable, not %s"%type(f))
1210 1207 if not isinstance(args, (tuple, list)):
1211 1208 raise TypeError("args must be tuple or list, not %s"%type(args))
1212 1209 if not isinstance(kwargs, dict):
1213 1210 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1214 1211 if not isinstance(metadata, dict):
1215 1212 raise TypeError("metadata must be dict, not %s"%type(metadata))
1216 1213
1217 1214 bufs = serialize.pack_apply_message(f, args, kwargs,
1218 1215 buffer_threshold=self.session.buffer_threshold,
1219 1216 item_threshold=self.session.item_threshold,
1220 1217 )
1221 1218
1222 1219 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1223 1220 metadata=metadata, track=track)
1224 1221
1225 1222 msg_id = msg['header']['msg_id']
1226 1223 self.outstanding.add(msg_id)
1227 1224 if ident:
1228 1225 # possibly routed to a specific engine
1229 1226 if isinstance(ident, list):
1230 1227 ident = ident[-1]
1231 1228 if ident in self._engines.values():
1232 1229 # save for later, in case of engine death
1233 1230 self._outstanding_dict[ident].add(msg_id)
1234 1231 self.history.append(msg_id)
1235 1232 self.metadata[msg_id]['submitted'] = datetime.now()
1236 1233
1237 1234 return msg
1238 1235
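    # Example (a sketch of what views do under the hood): route a function
    # to engine 0 via the mux socket, then wrap the reply:
    #
    #     idents, targets = rc._build_targets([0])
    #     msg = rc.send_apply_request(rc._mux_socket, pow, args=(2, 10),
    #                                 ident=idents[0])
    #     ar = rc.get_result(msg['header']['msg_id'])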
1239 1236 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1240 1237 """construct and send an execute request via a socket.
1241 1238
1242 1239 """
1243 1240
1244 1241 if self._closed:
1245 1242 raise RuntimeError("Client cannot be used after its sockets have been closed")
1246 1243
1247 1244 # defaults:
1248 1245 metadata = metadata if metadata is not None else {}
1249 1246
1250 1247 # validate arguments
1251 1248 if not isinstance(code, basestring):
1252 1249 raise TypeError("code must be text, not %s" % type(code))
1253 1250 if not isinstance(metadata, dict):
1254 1251 raise TypeError("metadata must be dict, not %s" % type(metadata))
1255 1252
1256 1253 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1257 1254
1258 1255
1259 1256 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1260 1257 metadata=metadata)
1261 1258
1262 1259 msg_id = msg['header']['msg_id']
1263 1260 self.outstanding.add(msg_id)
1264 1261 if ident:
1265 1262 # possibly routed to a specific engine
1266 1263 if isinstance(ident, list):
1267 1264 ident = ident[-1]
1268 1265 if ident in self._engines.values():
1269 1266 # save for later, in case of engine death
1270 1267 self._outstanding_dict[ident].add(msg_id)
1271 1268 self.history.append(msg_id)
1272 1269 self.metadata[msg_id]['submitted'] = datetime.now()
1273 1270
1274 1271 return msg
1275 1272
1276 1273 #--------------------------------------------------------------------------
1277 1274 # construct a View object
1278 1275 #--------------------------------------------------------------------------
1279 1276
1280 1277 def load_balanced_view(self, targets=None):
1281 1278         """construct a LoadBalancedView object.
1282 1279
1283 1280 If no arguments are specified, create a LoadBalancedView
1284 1281 using all engines.
1285 1282
1286 1283 Parameters
1287 1284 ----------
1288 1285
1289 1286 targets: list,slice,int,etc. [default: use all engines]
1290 1287 The subset of engines across which to load-balance
1291 1288 """
1292 1289 if targets == 'all':
1293 1290 targets = None
1294 1291 if targets is not None:
1295 1292 targets = self._build_targets(targets)[1]
1296 1293 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1297 1294
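    # Example (sketch): load-balance across all engines, or only a subset:
    #
    #     lview = rc.load_balanced_view()        # all engines
    #     lview = rc.load_balanced_view([0, 1])  # engines 0 and 1 only
    #     ar = lview.apply_async(sum, range(10))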
1298 1295 def direct_view(self, targets='all'):
1299 1296 """construct a DirectView object.
1300 1297
1301 1298 If no targets are specified, create a DirectView using all engines.
1302 1299
1303 1300 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1304 1301 evaluate the target engines at each execution, whereas rc[:] will connect to
1305 1302 all *current* engines, and that list will not change.
1306 1303
1307 1304 That is, 'all' will always use all engines, whereas rc[:] will not use
1308 1305 engines added after the DirectView is constructed.
1309 1306
1310 1307 Parameters
1311 1308 ----------
1312 1309
1313 1310 targets: list,slice,int,etc. [default: use all engines]
1314 1311 The engines to use for the View
1315 1312 """
1316 1313 single = isinstance(targets, int)
1317 1314 # allow 'all' to be lazily evaluated at each execution
1318 1315 if targets != 'all':
1319 1316 targets = self._build_targets(targets)[1]
1320 1317 if single:
1321 1318 targets = targets[0]
1322 1319 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1323 1320
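    # Example (sketch): lazy 'all' versus a snapshot of current engines:
    #
    #     dv_all = rc.direct_view('all')  # re-evaluates engines each execution
    #     dv_now = rc[:]                  # fixed to the engines present now
    #     e0 = rc[0]                      # single-engine view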
1324 1321 #--------------------------------------------------------------------------
1325 1322 # Query methods
1326 1323 #--------------------------------------------------------------------------
1327 1324
1328 1325 @spin_first
1329 1326 def get_result(self, indices_or_msg_ids=None, block=None):
1330 1327 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1331 1328
1332 1329 If the client already has the results, no request to the Hub will be made.
1333 1330
1334 1331 This is a convenient way to construct AsyncResult objects, which are wrappers
1335 1332 that include metadata about execution, and allow for awaiting results that
1336 1333 were not submitted by this Client.
1337 1334
1338 1335 It can also be a convenient way to retrieve the metadata associated with
1339 1336 blocking execution, since it always retrieves
1340 1337         blocking execution, since it always retrieves the full metadata record.
1341 1338 Examples
1342 1339 --------
1343 1340 ::
1344 1341
1345 1342             In [10]: ar = client.get_result(msg_id)
1346 1343
1347 1344 Parameters
1348 1345 ----------
1349 1346
1350 1347 indices_or_msg_ids : integer history index, str msg_id, or list of either
1351 1348             The indices or msg_ids of the tasks to be retrieved
1352 1349
1353 1350 block : bool
1354 1351 Whether to wait for the result to be done
1355 1352
1356 1353 Returns
1357 1354 -------
1358 1355
1359 1356 AsyncResult
1360 1357 A single AsyncResult object will always be returned.
1361 1358
1362 1359 AsyncHubResult
1363 1360 A subclass of AsyncResult that retrieves results from the Hub
1364 1361
1365 1362 """
1366 1363 block = self.block if block is None else block
1367 1364 if indices_or_msg_ids is None:
1368 1365 indices_or_msg_ids = -1
1369 1366
1370 1367 if not isinstance(indices_or_msg_ids, (list,tuple)):
1371 1368 indices_or_msg_ids = [indices_or_msg_ids]
1372 1369
1373 1370 theids = []
1374 1371 for id in indices_or_msg_ids:
1375 1372 if isinstance(id, int):
1376 1373 id = self.history[id]
1377 1374 if not isinstance(id, basestring):
1378 1375 raise TypeError("indices must be str or int, not %r"%id)
1379 1376 theids.append(id)
1380 1377
1381 1378 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1382 1379 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1383 1380
1384 1381 if remote_ids:
1385 1382 ar = AsyncHubResult(self, msg_ids=theids)
1386 1383 else:
1387 1384 ar = AsyncResult(self, msg_ids=theids)
1388 1385
1389 1386 if block:
1390 1387 ar.wait()
1391 1388
1392 1389 return ar
1393 1390
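    # Example (sketch; msg_id is a placeholder string): fetch by history
    # index or msg_id; locally-cached results skip the Hub round-trip:
    #
    #     ar = rc.get_result(-1)      # most recent submission
    #     ar = rc.get_result(msg_id)  # possibly submitted by another client
    #     ar.get()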
1394 1391 @spin_first
1395 1392 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1396 1393 """Resubmit one or more tasks.
1397 1394
1398 1395 in-flight tasks may not be resubmitted.
1399 1396
1400 1397 Parameters
1401 1398 ----------
1402 1399
1403 1400 indices_or_msg_ids : integer history index, str msg_id, or list of either
1404 1401 The indices or msg_ids of indices to be retrieved
1405 1402
1406 1403 block : bool
1407 1404 Whether to wait for the result to be done
1408 1405
1409 1406 Returns
1410 1407 -------
1411 1408
1412 1409 AsyncHubResult
1413 1410 A subclass of AsyncResult that retrieves results from the Hub
1414 1411
1415 1412 """
1416 1413 block = self.block if block is None else block
1417 1414 if indices_or_msg_ids is None:
1418 1415 indices_or_msg_ids = -1
1419 1416
1420 1417 if not isinstance(indices_or_msg_ids, (list,tuple)):
1421 1418 indices_or_msg_ids = [indices_or_msg_ids]
1422 1419
1423 1420 theids = []
1424 1421 for id in indices_or_msg_ids:
1425 1422 if isinstance(id, int):
1426 1423 id = self.history[id]
1427 1424 if not isinstance(id, basestring):
1428 1425 raise TypeError("indices must be str or int, not %r"%id)
1429 1426 theids.append(id)
1430 1427
1431 1428 content = dict(msg_ids = theids)
1432 1429
1433 1430 self.session.send(self._query_socket, 'resubmit_request', content)
1434 1431
1435 1432 zmq.select([self._query_socket], [], [])
1436 1433 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1437 1434 if self.debug:
1438 1435 pprint(msg)
1439 1436 content = msg['content']
1440 1437 if content['status'] != 'ok':
1441 1438 raise self._unwrap_exception(content)
1442 1439 mapping = content['resubmitted']
1443 1440 new_ids = [ mapping[msg_id] for msg_id in theids ]
1444 1441
1445 1442 ar = AsyncHubResult(self, msg_ids=new_ids)
1446 1443
1447 1444 if block:
1448 1445 ar.wait()
1449 1446
1450 1447 return ar
1451 1448
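    # Example (sketch; 'ar' is a finished AsyncResult): re-run the tasks and
    # wait on the new ids via the returned AsyncHubResult:
    #
    #     ar2 = rc.resubmit(ar.msg_ids)
    #     ar2.get()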
1452 1449 @spin_first
1453 1450 def result_status(self, msg_ids, status_only=True):
1454 1451 """Check on the status of the result(s) of the apply request with `msg_ids`.
1455 1452
1456 1453 If status_only is False, then the actual results will be retrieved, else
1457 1454 only the status of the results will be checked.
1458 1455
1459 1456 Parameters
1460 1457 ----------
1461 1458
1462 1459 msg_ids : list of msg_ids
1463 1460 if int:
1464 1461 Passed as index to self.history for convenience.
1465 1462 status_only : bool (default: True)
1466 1463 if False:
1467 1464 Retrieve the actual results of completed tasks.
1468 1465
1469 1466 Returns
1470 1467 -------
1471 1468
1472 1469 results : dict
1473 1470 There will always be the keys 'pending' and 'completed', which will
1474 1471 be lists of msg_ids that are incomplete or complete. If `status_only`
1475 1472 is False, then completed results will be keyed by their `msg_id`.
1476 1473 """
1477 1474 if not isinstance(msg_ids, (list,tuple)):
1478 1475 msg_ids = [msg_ids]
1479 1476
1480 1477 theids = []
1481 1478 for msg_id in msg_ids:
1482 1479 if isinstance(msg_id, int):
1483 1480 msg_id = self.history[msg_id]
1484 1481 if not isinstance(msg_id, basestring):
1485 1482 raise TypeError("msg_ids must be str, not %r"%msg_id)
1486 1483 theids.append(msg_id)
1487 1484
1488 1485 completed = []
1489 1486 local_results = {}
1490 1487
1491 1488 # comment this block out to temporarily disable local shortcut:
1492 1489         for msg_id in list(theids):  # copy, because theids is mutated below
1493 1490 if msg_id in self.results:
1494 1491 completed.append(msg_id)
1495 1492 local_results[msg_id] = self.results[msg_id]
1496 1493 theids.remove(msg_id)
1497 1494
1498 1495 if theids: # some not locally cached
1499 1496 content = dict(msg_ids=theids, status_only=status_only)
1500 1497 msg = self.session.send(self._query_socket, "result_request", content=content)
1501 1498 zmq.select([self._query_socket], [], [])
1502 1499 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1503 1500 if self.debug:
1504 1501 pprint(msg)
1505 1502 content = msg['content']
1506 1503 if content['status'] != 'ok':
1507 1504 raise self._unwrap_exception(content)
1508 1505 buffers = msg['buffers']
1509 1506 else:
1510 1507 content = dict(completed=[],pending=[])
1511 1508
1512 1509 content['completed'].extend(completed)
1513 1510
1514 1511 if status_only:
1515 1512 return content
1516 1513
1517 1514 failures = []
1518 1515 # load cached results into result:
1519 1516 content.update(local_results)
1520 1517
1521 1518 # update cache with results:
1522 1519 for msg_id in sorted(theids):
1523 1520 if msg_id in content['completed']:
1524 1521 rec = content[msg_id]
1525 1522 parent = rec['header']
1526 1523 header = rec['result_header']
1527 1524 rcontent = rec['result_content']
1528 1525 iodict = rec['io']
1529 1526 if isinstance(rcontent, str):
1530 1527 rcontent = self.session.unpack(rcontent)
1531 1528
1532 1529 md = self.metadata[msg_id]
1533 1530 md_msg = dict(
1534 1531 content=rcontent,
1535 1532 parent_header=parent,
1536 1533 header=header,
1537 1534 metadata=rec['result_metadata'],
1538 1535 )
1539 1536 md.update(self._extract_metadata(md_msg))
1540 1537 if rec.get('received'):
1541 1538 md['received'] = rec['received']
1542 1539 md.update(iodict)
1543 1540
1544 1541 if rcontent['status'] == 'ok':
1545 1542 if header['msg_type'] == 'apply_reply':
1546 1543 res,buffers = serialize.unserialize_object(buffers)
1547 1544 elif header['msg_type'] == 'execute_reply':
1548 1545 res = ExecuteReply(msg_id, rcontent, md)
1549 1546 else:
1550 1547                     raise KeyError("unhandled msg type: %r" % header['msg_type'])
1551 1548 else:
1552 1549 res = self._unwrap_exception(rcontent)
1553 1550 failures.append(res)
1554 1551
1555 1552 self.results[msg_id] = res
1556 1553 content[msg_id] = res
1557 1554
1558 1555 if len(theids) == 1 and failures:
1559 1556 raise failures[0]
1560 1557
1561 1558 error.collect_exceptions(failures, "result_status")
1562 1559 return content
1563 1560
1564 1561 @spin_first
1565 1562 def queue_status(self, targets='all', verbose=False):
1566 1563 """Fetch the status of engine queues.
1567 1564
1568 1565 Parameters
1569 1566 ----------
1570 1567
1571 1568 targets : int/str/list of ints/strs
1572 1569 the engines whose states are to be queried.
1573 1570 default : all
1574 1571 verbose : bool
1575 1572 Whether to return lengths only, or lists of ids for each element
1576 1573 """
1577 1574 if targets == 'all':
1578 1575 # allow 'all' to be evaluated on the engine
1579 1576 engine_ids = None
1580 1577 else:
1581 1578 engine_ids = self._build_targets(targets)[1]
1582 1579 content = dict(targets=engine_ids, verbose=verbose)
1583 1580 self.session.send(self._query_socket, "queue_request", content=content)
1584 1581 idents,msg = self.session.recv(self._query_socket, 0)
1585 1582 if self.debug:
1586 1583 pprint(msg)
1587 1584 content = msg['content']
1588 1585 status = content.pop('status')
1589 1586 if status != 'ok':
1590 1587 raise self._unwrap_exception(content)
1591 1588 content = rekey(content)
1592 1589 if isinstance(targets, int):
1593 1590 return content[targets]
1594 1591 else:
1595 1592 return content
1596 1593
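    # Example (a sketch of the reply shape; the exact keys come from the Hub,
    # so treat these names as an assumption):
    #
    #     rc.queue_status()
    #     # {0: {'queue': 0, 'completed': 10, 'tasks': 1}, ..., 'unassigned': 0}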
1597 1594 @spin_first
1598 1595 def purge_results(self, jobs=[], targets=[]):
1599 1596 """Tell the Hub to forget results.
1600 1597
1601 1598 Individual results can be purged by msg_id, or the entire
1602 1599 history of specific targets can be purged.
1603 1600
1604 1601 Use `purge_results('all')` to scrub everything from the Hub's db.
1605 1602
1606 1603 Parameters
1607 1604 ----------
1608 1605
1609 1606 jobs : str or list of str or AsyncResult objects
1610 1607 the msg_ids whose results should be forgotten.
1611 1608 targets : int/str/list of ints/strs
1612 1609 The targets, by int_id, whose entire history is to be purged.
1614 1611             [default: None]
1615 1612 """
1616 1613 if not targets and not jobs:
1617 1614 raise ValueError("Must specify at least one of `targets` and `jobs`")
1618 1615 if targets:
1619 1616 targets = self._build_targets(targets)[1]
1620 1617
1621 1618 # construct msg_ids from jobs
1622 1619 if jobs == 'all':
1623 1620 msg_ids = jobs
1624 1621 else:
1625 1622 msg_ids = []
1626 1623 if isinstance(jobs, (basestring,AsyncResult)):
1627 1624 jobs = [jobs]
1628 1625 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1629 1626 if bad_ids:
1630 1627 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1631 1628 for j in jobs:
1632 1629 if isinstance(j, AsyncResult):
1633 1630 msg_ids.extend(j.msg_ids)
1634 1631 else:
1635 1632 msg_ids.append(j)
1636 1633
1637 1634 content = dict(engine_ids=targets, msg_ids=msg_ids)
1638 1635 self.session.send(self._query_socket, "purge_request", content=content)
1639 1636 idents, msg = self.session.recv(self._query_socket, 0)
1640 1637 if self.debug:
1641 1638 pprint(msg)
1642 1639 content = msg['content']
1643 1640 if content['status'] != 'ok':
1644 1641 raise self._unwrap_exception(content)
1645 1642
1646 1643 @spin_first
1647 1644 def hub_history(self):
1648 1645 """Get the Hub's history
1649 1646
1650 1647 Just like the Client, the Hub has a history, which is a list of msg_ids.
1651 1648 This will contain the history of all clients, and, depending on configuration,
1652 1649 may contain history across multiple cluster sessions.
1653 1650
1654 1651 Any msg_id returned here is a valid argument to `get_result`.
1655 1652
1656 1653 Returns
1657 1654 -------
1658 1655
1659 1656 msg_ids : list of strs
1660 1657 list of all msg_ids, ordered by task submission time.
1661 1658 """
1662 1659
1663 1660 self.session.send(self._query_socket, "history_request", content={})
1664 1661 idents, msg = self.session.recv(self._query_socket, 0)
1665 1662
1666 1663 if self.debug:
1667 1664 pprint(msg)
1668 1665 content = msg['content']
1669 1666 if content['status'] != 'ok':
1670 1667 raise self._unwrap_exception(content)
1671 1668 else:
1672 1669 return content['history']
1673 1670
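    # Example (sketch): any id from the Hub's history is valid input to
    # get_result, even for tasks submitted by other clients:
    #
    #     ids = rc.hub_history()
    #     ar = rc.get_result(ids[-3:])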
1674 1671 @spin_first
1675 1672 def db_query(self, query, keys=None):
1676 1673 """Query the Hub's TaskRecord database
1677 1674
1678 1675 This will return a list of task record dicts that match `query`
1679 1676
1680 1677 Parameters
1681 1678 ----------
1682 1679
1683 1680 query : mongodb query dict
1684 1681 The search dict. See mongodb query docs for details.
1685 1682 keys : list of strs [optional]
1686 1683 The subset of keys to be returned. The default is to fetch everything but buffers.
1687 1684 'msg_id' will *always* be included.
1688 1685 """
1689 1686 if isinstance(keys, basestring):
1690 1687 keys = [keys]
1691 1688 content = dict(query=query, keys=keys)
1692 1689 self.session.send(self._query_socket, "db_request", content=content)
1693 1690 idents, msg = self.session.recv(self._query_socket, 0)
1694 1691 if self.debug:
1695 1692 pprint(msg)
1696 1693 content = msg['content']
1697 1694 if content['status'] != 'ok':
1698 1695 raise self._unwrap_exception(content)
1699 1696
1700 1697 records = content['records']
1701 1698
1702 1699 buffer_lens = content['buffer_lens']
1703 1700 result_buffer_lens = content['result_buffer_lens']
1704 1701 buffers = msg['buffers']
1705 1702 has_bufs = buffer_lens is not None
1706 1703 has_rbufs = result_buffer_lens is not None
1707 1704 for i,rec in enumerate(records):
1708 1705 # relink buffers
1709 1706 if has_bufs:
1710 1707 blen = buffer_lens[i]
1711 1708 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1712 1709 if has_rbufs:
1713 1710 blen = result_buffer_lens[i]
1714 1711 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1715 1712
1716 1713 return records
1717 1714
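# Example (sketch; 'ar' is an AsyncResult, and the key names are assumed from
# the usual TaskRecord fields): fetch timing data for a set of tasks with a
# mongodb-style query:
#
#     recs = rc.db_query({'msg_id': {'$in': ar.msg_ids}},
#                        keys=['started', 'completed', 'engine_uuid'])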
1718 1715 __all__ = [ 'Client' ]