##// END OF EJS Templates
store execute_replies in a nice wrapper
MinRK -
Show More
@@ -1,1566 +1,1620 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35
36 36 from IPython.utils.jsonutil import rekey
37 37 from IPython.utils.localinterfaces import LOCAL_IPS
38 38 from IPython.utils.path import get_ipython_dir
39 39 from IPython.utils.py3compat import cast_bytes
40 40 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
41 41 Dict, List, Bool, Set, Any)
42 42 from IPython.external.decorator import decorator
43 43 from IPython.external.ssh import tunnel
44 44
45 45 from IPython.parallel import Reference
46 46 from IPython.parallel import error
47 47 from IPython.parallel import util
48 48
49 49 from IPython.zmq.session import Session, Message
50 50
51 51 from .asyncresult import AsyncResult, AsyncHubResult
52 52 from IPython.core.profiledir import ProfileDir, ProfileDirError
53 53 from .view import DirectView, LoadBalancedView
54 54
if sys.version_info[0] >= 3:
    # Python 3 compatibility shim: xrange is used in a couple of
    # 'isinstance' tests below (py2); alias it to the py3 builtin range.
    xrange = range
59 59
60 60 #--------------------------------------------------------------------------
61 61 # Decorators for Client methods
62 62 #--------------------------------------------------------------------------
63 63
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # flush incoming messages so results/engine registration are current
    # before the decorated Client method runs
    self.spin()
    return f(self, *args, **kwargs)
69 69
70 70
71 71 #--------------------------------------------------------------------------
72 72 # Classes
73 73 #--------------------------------------------------------------------------
74 74
75
class ExecuteReply(object):
    """Wrapper for a finished execute result.

    Exposes the reply metadata via dict-style (``reply[key]``) and
    attribute-style (``reply.key``) access, and implements the IPython
    ``_repr_*_`` display hooks backed by the mimebundle stored in
    ``metadata['pyout']``.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    def _pyout_data(self, mimetype):
        """Return the pyout data for *mimetype*, or None if unavailable.

        ``metadata['pyout']`` may be None when the execution produced
        no displayhook output.
        """
        pyout = self.metadata['pyout'] or {}
        return pyout.get(mimetype)

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        # only invoked when normal attribute lookup fails:
        # fall back to metadata keys
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        # `or ''` also covers an explicit None stored under text/plain,
        # which would otherwise break len() below
        text_out = self._pyout_data('text/plain') or ''
        if len(text_out) > 32:
            # truncate long output to keep the repr one-line friendly
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_html_(self):
        return self._pyout_data("text/html")

    def _repr_latex_(self):
        return self._pyout_data("text/latex")

    def _repr_json_(self):
        return self._pyout_data("application/json")

    def _repr_javascript_(self):
        return self._pyout_data("application/javascript")

    def _repr_png_(self):
        return self._pyout_data("image/png")

    def _repr_jpeg_(self):
        return self._pyout_data("image/jpeg")

    def _repr_svg_(self):
        return self._pyout_data("image/svg+xml")
128
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the complete set of allowed keys, with their initial values;
        # anything not listed here can never be added later
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
            }
        # note: dict.update bypasses our strict __setitem__, which is what
        # lets the initial population (and caller overrides) through
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # O(1) hash lookup (the old `key in self.iterkeys()` was a linear
        # scan over an iterator, and is not Python 3 compatible)
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
127 181
128 182 class Client(HasTraits):
129 183 """A semi-synchronous client to the IPython ZMQ cluster
130 184
131 185 Parameters
132 186 ----------
133 187
134 188 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
135 189 Connection information for the Hub's registration. If a json connector
136 190 file is given, then likely no further configuration is necessary.
137 191 [Default: use profile]
138 192 profile : bytes
139 193 The name of the Cluster profile to be used to find connector information.
140 194 If run from an IPython application, the default profile will be the same
141 195 as the running application, otherwise it will be 'default'.
142 196 context : zmq.Context
143 197 Pass an existing zmq.Context instance, otherwise the client will create its own.
144 198 debug : bool
145 199 flag for lots of message printing for debug purposes
146 200 timeout : int/float
147 201 time (in seconds) to wait for connection replies from the Hub
148 202 [Default: 10]
149 203
150 204 #-------------- session related args ----------------
151 205
152 206 config : Config object
153 207 If specified, this will be relayed to the Session for configuration
154 208 username : str
155 209 set username for the session object
156 210 packer : str (import_string) or callable
157 211 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
158 212 function to serialize messages. Must support same input as
159 213 JSON, and output must be bytes.
160 214 You can pass a callable directly as `pack`
161 215 unpacker : str (import_string) or callable
162 216 The inverse of packer. Only necessary if packer is specified as *not* one
163 217 of 'json' or 'pickle'.
164 218
165 219 #-------------- ssh related args ----------------
166 220 # These are args for configuring the ssh tunnel to be used
167 221 # credentials are used to forward connections over ssh to the Controller
168 222 # Note that the ip given in `addr` needs to be relative to sshserver
169 223 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
170 224 # and set sshserver as the same machine the Controller is on. However,
171 225 # the only requirement is that sshserver is able to see the Controller
172 226 # (i.e. is within the same trusted network).
173 227
174 228 sshserver : str
175 229 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
176 230 If keyfile or password is specified, and this is not, it will default to
177 231 the ip given in addr.
178 232 sshkey : str; path to ssh private key file
179 233 This specifies a key to be used in ssh login, default None.
180 234 Regular default ssh keys will be used without specifying this argument.
181 235 password : str
182 236 Your ssh password to sshserver. Note that if this is left None,
183 237 you will be prompted for it if passwordless key based login is unavailable.
184 238 paramiko : bool
185 239 flag for whether to use paramiko instead of shell ssh for tunneling.
186 240 [default: True on win32, False else]
187 241
188 242 ------- exec authentication args -------
189 243 If even localhost is untrusted, you can have some protection against
190 244 unauthorized execution by signing messages with HMAC digests.
191 245 Messages are still sent as cleartext, so if someone can snoop your
192 246 loopback traffic this will not protect your privacy, but will prevent
193 247 unauthorized execution.
194 248
195 249 exec_key : str
196 250 an authentication key or file containing a key
197 251 default: None
198 252
199 253
200 254 Attributes
201 255 ----------
202 256
203 257 ids : list of int engine IDs
204 258 requesting the ids attribute always synchronizes
205 259 the registration state. To request ids without synchronization,
206 260 use semi-private _ids attributes.
207 261
208 262 history : list of msg_ids
209 263 a list of msg_ids, keeping track of all the execution
210 264 messages you have submitted in order.
211 265
212 266 outstanding : set of msg_ids
213 267 a set of msg_ids that have been submitted, but whose
214 268 results have not yet been received.
215 269
216 270 results : dict
217 271 a dict of all our results, keyed by msg_id
218 272
219 273 block : bool
220 274 determines default behavior when block not specified
221 275 in execution methods
222 276
223 277 Methods
224 278 -------
225 279
226 280 spin
227 281 flushes incoming results and registration state changes
228 282 control methods spin, and requesting `ids` also ensures up to date
229 283
230 284 wait
231 285 wait on one or more msg_ids
232 286
233 287 execution methods
234 288 apply
235 289 legacy: execute, run
236 290
237 291 data movement
238 292 push, pull, scatter, gather
239 293
240 294 query methods
241 295 queue_status, get_result, purge, result_status
242 296
243 297 control methods
244 298 abort, shutdown
245 299
246 300 """
247 301
248 302
    # default block behavior for execution methods
    block = Bool(False)
    # msg_ids submitted but not yet resolved
    outstanding = Set()
    # all received results, keyed by msg_id
    results = Instance('collections.defaultdict', (dict,))
    # per-msg_id Metadata objects
    metadata = Instance('collections.defaultdict', (Metadata,))
    # every msg_id ever submitted, in order
    history = List()
    debug = Bool(False)
    _spin_thread = Any()
    _stop_spinning = Any()

    profile=Unicode()
    def _profile_default(self):
        # dynamic default: inherit the profile of a running IPython app,
        # falling back to 'default'
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'


    # outstanding msg_ids grouped by engine uuid
    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    # bidirectional engine_id <-> uuid mapping
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()
    _closed = False
    # counters of replies we expect but will discard
    _ignored_control_replies=Integer(0)
    _ignored_hub_replies=Integer(0)
290 344
    def __new__(self, *args, **kw):
        # don't raise on positional args: HasTraits.__new__ rejects them,
        # but our __init__ consumes them, so drop them here
        return HasTraits.__new__(self, **kw)
294 348
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Locate connection info, build the Session, and connect to the Hub.

        See the class docstring for parameter descriptions.  Remaining
        keyword args are relayed to the Session constructor.
        """
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context
        self._stop_spinning = Event()

        # resolve the connection file from the profile dir if not given
        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                if self._cd:
                    # allow bare filenames relative to the security dir
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit arguments take precedence over the connection file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                # interactive prompt as a last resort
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # a path: let the Session read the key from the file
                extra_args['keyfile'] = exec_key
            else:
                # a literal key: Session requires bytes
                exec_key = cast_bytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
394 448
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the context may be shared (zmq.Context.instance()), so it is
        # deliberately left alone
        self.close()
398 452
399 453 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
400 454 if ipython_dir is None:
401 455 ipython_dir = get_ipython_dir()
402 456 if profile_dir is not None:
403 457 try:
404 458 self._cd = ProfileDir.find_profile_dir(profile_dir)
405 459 return
406 460 except ProfileDirError:
407 461 pass
408 462 elif profile is not None:
409 463 try:
410 464 self._cd = ProfileDir.find_profile_dir_by_name(
411 465 ipython_dir, profile)
412 466 return
413 467 except ProfileDirError:
414 468 pass
415 469 self._cd = None
416 470
    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            # JSON keys arrive as strings; engine ids are ints
            eid = int(k)
            self._engines[eid] = v
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        # a pure ZMQ scheduler requires a contiguous 0..n-1 id range;
        # any gap means an engine died and task farming must stop
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()
427 481
428 482 def _stop_scheduling_tasks(self):
429 483 """Stop scheduling tasks because an engine has been unregistered
430 484 from a pure ZMQ scheduler.
431 485 """
432 486 self._task_socket.close()
433 487 self._task_socket = None
434 488 msg = "An engine has been unregistered, and we are using pure " +\
435 489 "ZMQ task scheduling. Task farming will be disabled."
436 490 if self.outstanding:
437 491 msg += " If you were running tasks when this happened, " +\
438 492 "some `outstanding` msg_ids may never resolve."
439 493 warnings.warn(msg, RuntimeWarning)
440 494
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            # (self.ids is a property that triggers the flush)
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index counts from the end, like list indexing
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # map the slice over positions, then back to engine ids
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        # zmq identities must be bytes
        return [cast_bytes(self._engines[t]) for t in targets], list(targets)
473 527
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # helper: resolve the url relative to the controller location,
            # tunneling over ssh when configured
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # one socket per channel the Hub advertised in its reply
            ident = self.session.bsession
            if content.mux:
                self._mux_socket = self._context.socket(zmq.DEALER)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.DEALER)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # subscribe to everything
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.DEALER)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.DEALER)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
536 590
537 591 #--------------------------------------------------------------------------
538 592 # handlers and callbacks for incoming messages
539 593 #--------------------------------------------------------------------------
540 594
    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            # _engines is a ReverseDict, so a uuid key maps back to its
            # integer engine id
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e
550 604
551 605 def _extract_metadata(self, header, parent, content):
552 606 md = {'msg_id' : parent['msg_id'],
553 607 'received' : datetime.now(),
554 608 'engine_uuid' : header.get('engine', None),
555 609 'follow' : parent.get('follow', []),
556 610 'after' : parent.get('after', []),
557 611 'status' : content['status'],
558 612 }
559 613
560 614 if md['engine_uuid'] is not None:
561 615 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
562 616
563 617 if 'date' in parent:
564 618 md['submitted'] = parent['date']
565 619 if 'started' in header:
566 620 md['started'] = header['started']
567 621 if 'date' in header:
568 622 md['completed'] = header['date']
569 623 return md
570 624
571 625 def _register_engine(self, msg):
572 626 """Register a new engine, and update our connection info."""
573 627 content = msg['content']
574 628 eid = content['id']
575 629 d = {eid : content['queue']}
576 630 self._update_engines(d)
577 631
    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            # mark results still pending on that engine as failed
            self._handle_stranded_msgs(eid, uuid)

        # losing an engine breaks a pure ZMQ scheduler's id range
        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()
590 644
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result; nothing to do
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                # raise/except so wrap_exception can capture a real traceback
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            # route through the normal reply handler so bookkeeping happens
            self._handle_apply_reply(msg)
617 671
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # a reply we were not waiting for: stale (already resolved)
            # or entirely unknown
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)

        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the per-engine outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            # wrap in ExecuteReply so display hooks and metadata access work
            self.results[msg_id] = ExecuteReply(msg_id, content, md)
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
657 711
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the per-engine outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            # the actual return value travels in the message buffers
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
694 748
695 749 def _flush_notifications(self):
696 750 """Flush notifications of engine registrations waiting
697 751 in ZMQ queue."""
698 752 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
699 753 while msg is not None:
700 754 if self.debug:
701 755 pprint(msg)
702 756 msg_type = msg['header']['msg_type']
703 757 handler = self._notification_handlers.get(msg_type, None)
704 758 if handler is None:
705 759 raise Exception("Unhandled message type: %s"%msg.msg_type)
706 760 else:
707 761 handler(msg)
708 762 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
709 763
710 764 def _flush_results(self, sock):
711 765 """Flush task or queue results waiting in ZMQ queue."""
712 766 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
713 767 while msg is not None:
714 768 if self.debug:
715 769 pprint(msg)
716 770 msg_type = msg['header']['msg_type']
717 771 handler = self._queue_handlers.get(msg_type, None)
718 772 if handler is None:
719 773 raise Exception("Unhandled message type: %s"%msg.msg_type)
720 774 else:
721 775 handler(msg)
722 776 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
723 777
724 778 def _flush_control(self, sock):
725 779 """Flush replies from the control channel waiting
726 780 in the ZMQ queue.
727 781
728 782 Currently: ignore them."""
729 783 if self._ignored_control_replies <= 0:
730 784 return
731 785 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
732 786 while msg is not None:
733 787 self._ignored_control_replies -= 1
734 788 if self.debug:
735 789 pprint(msg)
736 790 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
737 791
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv: each counted reply is guaranteed to arrive eventually
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
743 797
744 798 def _flush_ignored_hub_replies(self):
745 799 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
746 800 while msg is not None:
747 801 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
748 802
749 803 def _flush_iopub(self, sock):
750 804 """Flush replies from the iopub channel waiting
751 805 in the ZMQ queue.
752 806 """
753 807 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
754 808 while msg is not None:
755 809 if self.debug:
756 810 pprint(msg)
757 811 parent = msg['parent_header']
758 812 # ignore IOPub messages with no parent.
759 813 # Caused by print statements or warnings from before the first execution.
760 814 if not parent:
761 815 continue
762 816 msg_id = parent['msg_id']
763 817 content = msg['content']
764 818 header = msg['header']
765 819 msg_type = msg['header']['msg_type']
766 820
767 821 # init metadata:
768 822 md = self.metadata[msg_id]
769 823
770 824 if msg_type == 'stream':
771 825 name = content['name']
772 826 s = md[name] or ''
773 827 md[name] = s + content['data']
774 828 elif msg_type == 'pyerr':
775 829 md.update({'pyerr' : self._unwrap_exception(content)})
776 830 elif msg_type == 'pyin':
777 831 md.update({'pyin' : content['code']})
778 832 elif msg_type == 'display_data':
779 833 md['outputs'].append(content.get('data'))
780 834 elif msg_type == 'pyout':
781 835 md['pyout'] = content.get('data')
782 836 else:
783 837 # unhandled msg_type (status, etc.)
784 838 pass
785 839
786 840 # reduntant?
787 841 self.metadata[msg_id] = md
788 842
789 843 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
790 844
791 845 #--------------------------------------------------------------------------
792 846 # len, getitem
793 847 #--------------------------------------------------------------------------
794 848
    def __len__(self):
        """len(client) returns # of engines."""
        # self.ids is a property, so this also syncs registration state
        return len(self.ids)
798 852
    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
        else:
            # target validation happens inside direct_view
            return self.direct_view(key)
807 861
808 862 #--------------------------------------------------------------------------
809 863 # Begin public methods
810 864 #--------------------------------------------------------------------------
811 865
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # drain pending (un)registration notifications so the list is current
        self._flush_notifications()
        # always copy: callers must not mutate our internal list
        return list(self._ids)
818 872
819 873 def close(self):
820 874 if self._closed:
821 875 return
822 876 self.stop_spin_thread()
823 877 snames = filter(lambda n: n.endswith('socket'), dir(self))
824 878 for socket in map(lambda name: getattr(self, name), snames):
825 879 if isinstance(socket, zmq.Socket) and not socket.closed:
826 880 socket.close()
827 881 self._closed = True
828 882
829 883 def _spin_every(self, interval=1):
830 884 """target func for use in spin_thread"""
831 885 while True:
832 886 if self._stop_spinning.is_set():
833 887 return
834 888 time.sleep(interval)
835 889 self.spin()
836 890
837 891 def spin_thread(self, interval=1):
838 892 """call Client.spin() in a background thread on some regular interval
839 893
840 894 This helps ensure that messages don't pile up too much in the zmq queue
841 895 while you are working on other things, or just leaving an idle terminal.
842 896
843 897 It also helps limit potential padding of the `received` timestamp
844 898 on AsyncResult objects, used for timings.
845 899
846 900 Parameters
847 901 ----------
848 902
849 903 interval : float, optional
850 904 The interval on which to spin the client in the background thread
851 905 (simply passed to time.sleep).
852 906
853 907 Notes
854 908 -----
855 909
856 910 For precision timing, you may want to use this method to put a bound
857 911 on the jitter (in seconds) in `received` timestamps used
858 912 in AsyncResult.wall_time.
859 913
860 914 """
861 915 if self._spin_thread is not None:
862 916 self.stop_spin_thread()
863 917 self._stop_spinning.clear()
864 918 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
865 919 self._spin_thread.daemon = True
866 920 self._spin_thread.start()
867 921
868 922 def stop_spin_thread(self):
869 923 """stop background spin_thread, if any"""
870 924 if self._spin_thread is not None:
871 925 self._stop_spinning.set()
872 926 self._spin_thread.join()
873 927 self._spin_thread = None
874 928
875 929 def spin(self):
876 930 """Flush any registration notifications and execution results
877 931 waiting in the ZMQ queue.
878 932 """
879 933 if self._notification_socket:
880 934 self._flush_notifications()
935 if self._iopub_socket:
936 self._flush_iopub(self._iopub_socket)
881 937 if self._mux_socket:
882 938 self._flush_results(self._mux_socket)
883 939 if self._task_socket:
884 940 self._flush_results(self._task_socket)
885 941 if self._control_socket:
886 942 self._flush_control(self._control_socket)
887 if self._iopub_socket:
888 self._flush_iopub(self._iopub_socket)
889 943 if self._query_socket:
890 944 self._flush_ignored_hub_replies()
891 945
892 946 def wait(self, jobs=None, timeout=-1):
893 947 """waits on one or more `jobs`, for up to `timeout` seconds.
894 948
895 949 Parameters
896 950 ----------
897 951
898 952 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
899 953 ints are indices to self.history
900 954 strs are msg_ids
901 955 default: wait on all outstanding messages
902 956 timeout : float
903 957 a time in seconds, after which to give up.
904 958 default is -1, which means no timeout
905 959
906 960 Returns
907 961 -------
908 962
909 963 True : when all msg_ids are done
910 964 False : timeout reached, some msg_ids still outstanding
911 965 """
912 966 tic = time.time()
913 967 if jobs is None:
914 968 theids = self.outstanding
915 969 else:
916 970 if isinstance(jobs, (int, basestring, AsyncResult)):
917 971 jobs = [jobs]
918 972 theids = set()
919 973 for job in jobs:
920 974 if isinstance(job, int):
921 975 # index access
922 976 job = self.history[job]
923 977 elif isinstance(job, AsyncResult):
924 978 map(theids.add, job.msg_ids)
925 979 continue
926 980 theids.add(job)
927 981 if not theids.intersection(self.outstanding):
928 982 return True
929 983 self.spin()
930 984 while theids.intersection(self.outstanding):
931 985 if timeout >= 0 and ( time.time()-tic ) > timeout:
932 986 break
933 987 time.sleep(1e-3)
934 988 self.spin()
935 989 return len(theids.intersection(self.outstanding)) == 0
936 990
937 991 #--------------------------------------------------------------------------
938 992 # Control methods
939 993 #--------------------------------------------------------------------------
940 994
941 995 @spin_first
942 996 def clear(self, targets=None, block=None):
943 997 """Clear the namespace in target(s)."""
944 998 block = self.block if block is None else block
945 999 targets = self._build_targets(targets)[0]
946 1000 for t in targets:
947 1001 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
948 1002 error = False
949 1003 if block:
950 1004 self._flush_ignored_control()
951 1005 for i in range(len(targets)):
952 1006 idents,msg = self.session.recv(self._control_socket,0)
953 1007 if self.debug:
954 1008 pprint(msg)
955 1009 if msg['content']['status'] != 'ok':
956 1010 error = self._unwrap_exception(msg['content'])
957 1011 else:
958 1012 self._ignored_control_replies += len(targets)
959 1013 if error:
960 1014 raise error
961 1015
962 1016
963 1017 @spin_first
964 1018 def abort(self, jobs=None, targets=None, block=None):
965 1019 """Abort specific jobs from the execution queues of target(s).
966 1020
967 1021 This is a mechanism to prevent jobs that have already been submitted
968 1022 from executing.
969 1023
970 1024 Parameters
971 1025 ----------
972 1026
973 1027 jobs : msg_id, list of msg_ids, or AsyncResult
974 1028 The jobs to be aborted
975 1029
976 1030 If unspecified/None: abort all outstanding jobs.
977 1031
978 1032 """
979 1033 block = self.block if block is None else block
980 1034 jobs = jobs if jobs is not None else list(self.outstanding)
981 1035 targets = self._build_targets(targets)[0]
982 1036
983 1037 msg_ids = []
984 1038 if isinstance(jobs, (basestring,AsyncResult)):
985 1039 jobs = [jobs]
986 1040 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
987 1041 if bad_ids:
988 1042 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
989 1043 for j in jobs:
990 1044 if isinstance(j, AsyncResult):
991 1045 msg_ids.extend(j.msg_ids)
992 1046 else:
993 1047 msg_ids.append(j)
994 1048 content = dict(msg_ids=msg_ids)
995 1049 for t in targets:
996 1050 self.session.send(self._control_socket, 'abort_request',
997 1051 content=content, ident=t)
998 1052 error = False
999 1053 if block:
1000 1054 self._flush_ignored_control()
1001 1055 for i in range(len(targets)):
1002 1056 idents,msg = self.session.recv(self._control_socket,0)
1003 1057 if self.debug:
1004 1058 pprint(msg)
1005 1059 if msg['content']['status'] != 'ok':
1006 1060 error = self._unwrap_exception(msg['content'])
1007 1061 else:
1008 1062 self._ignored_control_replies += len(targets)
1009 1063 if error:
1010 1064 raise error
1011 1065
    @spin_first
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        Parameters
        ----------

        targets : list of ints, or 'all' [default: None]
            the engines to shut down; forced to 'all' when `hub` is True
        restart : bool [default: False]
            passed through to the engines in the shutdown_request content
        hub : bool [default: False]
            also send a shutdown_request to the Hub (after the engines)
        block : bool [default: self.block]
            wait for the engine replies
        """
        block = self.block if block is None else block
        if hub:
            # the hub cannot outlive its engines
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        # hub shutdown always waits for engine replies first, even if block=False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            # brief pause so engine shutdown can settle before the hub goes away
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
1045 1099
1046 1100 #--------------------------------------------------------------------------
1047 1101 # Execution related methods
1048 1102 #--------------------------------------------------------------------------
1049 1103
1050 1104 def _maybe_raise(self, result):
1051 1105 """wrapper for maybe raising an exception if apply failed."""
1052 1106 if isinstance(result, error.RemoteError):
1053 1107 raise result
1054 1108
1055 1109 return result
1056 1110
1057 1111 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1058 1112 ident=None):
1059 1113 """construct and send an apply message via a socket.
1060 1114
1061 1115 This is the principal method with which all engine execution is performed by views.
1062 1116 """
1063 1117
1064 1118 assert not self._closed, "cannot use me anymore, I'm closed!"
1065 1119 # defaults:
1066 1120 args = args if args is not None else []
1067 1121 kwargs = kwargs if kwargs is not None else {}
1068 1122 subheader = subheader if subheader is not None else {}
1069 1123
1070 1124 # validate arguments
1071 1125 if not callable(f) and not isinstance(f, Reference):
1072 1126 raise TypeError("f must be callable, not %s"%type(f))
1073 1127 if not isinstance(args, (tuple, list)):
1074 1128 raise TypeError("args must be tuple or list, not %s"%type(args))
1075 1129 if not isinstance(kwargs, dict):
1076 1130 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1077 1131 if not isinstance(subheader, dict):
1078 1132 raise TypeError("subheader must be dict, not %s"%type(subheader))
1079 1133
1080 1134 bufs = util.pack_apply_message(f,args,kwargs)
1081 1135
1082 1136 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1083 1137 subheader=subheader, track=track)
1084 1138
1085 1139 msg_id = msg['header']['msg_id']
1086 1140 self.outstanding.add(msg_id)
1087 1141 if ident:
1088 1142 # possibly routed to a specific engine
1089 1143 if isinstance(ident, list):
1090 1144 ident = ident[-1]
1091 1145 if ident in self._engines.values():
1092 1146 # save for later, in case of engine death
1093 1147 self._outstanding_dict[ident].add(msg_id)
1094 1148 self.history.append(msg_id)
1095 1149 self.metadata[msg_id]['submitted'] = datetime.now()
1096 1150
1097 1151 return msg
1098 1152
1099 1153 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1100 1154 """construct and send an execute request via a socket.
1101 1155
1102 1156 """
1103 1157
1104 1158 assert not self._closed, "cannot use me anymore, I'm closed!"
1105 1159 # defaults:
1106 1160 subheader = subheader if subheader is not None else {}
1107 1161
1108 1162 # validate arguments
1109 1163 if not isinstance(code, basestring):
1110 1164 raise TypeError("code must be text, not %s" % type(code))
1111 1165 if not isinstance(subheader, dict):
1112 1166 raise TypeError("subheader must be dict, not %s" % type(subheader))
1113 1167
1114 1168 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1115 1169
1116 1170
1117 1171 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1118 1172 subheader=subheader)
1119 1173
1120 1174 msg_id = msg['header']['msg_id']
1121 1175 self.outstanding.add(msg_id)
1122 1176 if ident:
1123 1177 # possibly routed to a specific engine
1124 1178 if isinstance(ident, list):
1125 1179 ident = ident[-1]
1126 1180 if ident in self._engines.values():
1127 1181 # save for later, in case of engine death
1128 1182 self._outstanding_dict[ident].add(msg_id)
1129 1183 self.history.append(msg_id)
1130 1184 self.metadata[msg_id]['submitted'] = datetime.now()
1131 1185
1132 1186 return msg
1133 1187
1134 1188 #--------------------------------------------------------------------------
1135 1189 # construct a View object
1136 1190 #--------------------------------------------------------------------------
1137 1191
    def load_balanced_view(self, targets=None):
        """construct a LoadBalancedView object.

        If no arguments are specified, create a LoadBalancedView
        using all engines.

        Parameters
        ----------

        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        if targets == 'all':
            # 'all' means no restriction: same as passing None
            targets = None
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1155 1209
1156 1210 def direct_view(self, targets='all'):
1157 1211 """construct a DirectView object.
1158 1212
1159 1213 If no targets are specified, create a DirectView using all engines.
1160 1214
1161 1215 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1162 1216 evaluate the target engines at each execution, whereas rc[:] will connect to
1163 1217 all *current* engines, and that list will not change.
1164 1218
1165 1219 That is, 'all' will always use all engines, whereas rc[:] will not use
1166 1220 engines added after the DirectView is constructed.
1167 1221
1168 1222 Parameters
1169 1223 ----------
1170 1224
1171 1225 targets: list,slice,int,etc. [default: use all engines]
1172 1226 The engines to use for the View
1173 1227 """
1174 1228 single = isinstance(targets, int)
1175 1229 # allow 'all' to be lazily evaluated at each execution
1176 1230 if targets != 'all':
1177 1231 targets = self._build_targets(targets)[1]
1178 1232 if single:
1179 1233 targets = targets[0]
1180 1234 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1181 1235
1182 1236 #--------------------------------------------------------------------------
1183 1237 # Query methods
1184 1238 #--------------------------------------------------------------------------
1185 1239
1186 1240 @spin_first
1187 1241 def get_result(self, indices_or_msg_ids=None, block=None):
1188 1242 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1189 1243
1190 1244 If the client already has the results, no request to the Hub will be made.
1191 1245
1192 1246 This is a convenient way to construct AsyncResult objects, which are wrappers
1193 1247 that include metadata about execution, and allow for awaiting results that
1194 1248 were not submitted by this Client.
1195 1249
1196 1250 It can also be a convenient way to retrieve the metadata associated with
1197 1251 blocking execution, since it always retrieves
1198 1252
1199 1253 Examples
1200 1254 --------
1201 1255 ::
1202 1256
1203 1257 In [10]: r = client.apply()
1204 1258
1205 1259 Parameters
1206 1260 ----------
1207 1261
1208 1262 indices_or_msg_ids : integer history index, str msg_id, or list of either
1209 1263 The indices or msg_ids of indices to be retrieved
1210 1264
1211 1265 block : bool
1212 1266 Whether to wait for the result to be done
1213 1267
1214 1268 Returns
1215 1269 -------
1216 1270
1217 1271 AsyncResult
1218 1272 A single AsyncResult object will always be returned.
1219 1273
1220 1274 AsyncHubResult
1221 1275 A subclass of AsyncResult that retrieves results from the Hub
1222 1276
1223 1277 """
1224 1278 block = self.block if block is None else block
1225 1279 if indices_or_msg_ids is None:
1226 1280 indices_or_msg_ids = -1
1227 1281
1228 1282 if not isinstance(indices_or_msg_ids, (list,tuple)):
1229 1283 indices_or_msg_ids = [indices_or_msg_ids]
1230 1284
1231 1285 theids = []
1232 1286 for id in indices_or_msg_ids:
1233 1287 if isinstance(id, int):
1234 1288 id = self.history[id]
1235 1289 if not isinstance(id, basestring):
1236 1290 raise TypeError("indices must be str or int, not %r"%id)
1237 1291 theids.append(id)
1238 1292
1239 1293 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1240 1294 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1241 1295
1242 1296 if remote_ids:
1243 1297 ar = AsyncHubResult(self, msg_ids=theids)
1244 1298 else:
1245 1299 ar = AsyncResult(self, msg_ids=theids)
1246 1300
1247 1301 if block:
1248 1302 ar.wait()
1249 1303
1250 1304 return ar
1251 1305
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default: resubmit the most recent task
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # normalize history indices to msg_ids
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        # block until the Hub's reply is ready, then read it without blocking
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        # the Hub assigns fresh msg_ids; map the requested ids to the new ones
        mapping = content['resubmitted']
        new_ids = [ mapping[msg_id] for msg_id in theids ]

        ar = AsyncHubResult(self, msg_ids=new_ids)

        if block:
            ar.wait()

        return ar
1309 1363
1310 1364 @spin_first
1311 1365 def result_status(self, msg_ids, status_only=True):
1312 1366 """Check on the status of the result(s) of the apply request with `msg_ids`.
1313 1367
1314 1368 If status_only is False, then the actual results will be retrieved, else
1315 1369 only the status of the results will be checked.
1316 1370
1317 1371 Parameters
1318 1372 ----------
1319 1373
1320 1374 msg_ids : list of msg_ids
1321 1375 if int:
1322 1376 Passed as index to self.history for convenience.
1323 1377 status_only : bool (default: True)
1324 1378 if False:
1325 1379 Retrieve the actual results of completed tasks.
1326 1380
1327 1381 Returns
1328 1382 -------
1329 1383
1330 1384 results : dict
1331 1385 There will always be the keys 'pending' and 'completed', which will
1332 1386 be lists of msg_ids that are incomplete or complete. If `status_only`
1333 1387 is False, then completed results will be keyed by their `msg_id`.
1334 1388 """
1335 1389 if not isinstance(msg_ids, (list,tuple)):
1336 1390 msg_ids = [msg_ids]
1337 1391
1338 1392 theids = []
1339 1393 for msg_id in msg_ids:
1340 1394 if isinstance(msg_id, int):
1341 1395 msg_id = self.history[msg_id]
1342 1396 if not isinstance(msg_id, basestring):
1343 1397 raise TypeError("msg_ids must be str, not %r"%msg_id)
1344 1398 theids.append(msg_id)
1345 1399
1346 1400 completed = []
1347 1401 local_results = {}
1348 1402
1349 1403 # comment this block out to temporarily disable local shortcut:
1350 1404 for msg_id in theids:
1351 1405 if msg_id in self.results:
1352 1406 completed.append(msg_id)
1353 1407 local_results[msg_id] = self.results[msg_id]
1354 1408 theids.remove(msg_id)
1355 1409
1356 1410 if theids: # some not locally cached
1357 1411 content = dict(msg_ids=theids, status_only=status_only)
1358 1412 msg = self.session.send(self._query_socket, "result_request", content=content)
1359 1413 zmq.select([self._query_socket], [], [])
1360 1414 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1361 1415 if self.debug:
1362 1416 pprint(msg)
1363 1417 content = msg['content']
1364 1418 if content['status'] != 'ok':
1365 1419 raise self._unwrap_exception(content)
1366 1420 buffers = msg['buffers']
1367 1421 else:
1368 1422 content = dict(completed=[],pending=[])
1369 1423
1370 1424 content['completed'].extend(completed)
1371 1425
1372 1426 if status_only:
1373 1427 return content
1374 1428
1375 1429 failures = []
1376 1430 # load cached results into result:
1377 1431 content.update(local_results)
1378 1432
1379 1433 # update cache with results:
1380 1434 for msg_id in sorted(theids):
1381 1435 if msg_id in content['completed']:
1382 1436 rec = content[msg_id]
1383 1437 parent = rec['header']
1384 1438 header = rec['result_header']
1385 1439 rcontent = rec['result_content']
1386 1440 iodict = rec['io']
1387 1441 if isinstance(rcontent, str):
1388 1442 rcontent = self.session.unpack(rcontent)
1389 1443
1390 1444 md = self.metadata[msg_id]
1391 1445 md.update(self._extract_metadata(header, parent, rcontent))
1392 1446 if rec.get('received'):
1393 1447 md['received'] = rec['received']
1394 1448 md.update(iodict)
1395 1449
1396 1450 if rcontent['status'] == 'ok':
1397 1451 res,buffers = util.unserialize_object(buffers)
1398 1452 else:
1399 1453 print rcontent
1400 1454 res = self._unwrap_exception(rcontent)
1401 1455 failures.append(res)
1402 1456
1403 1457 self.results[msg_id] = res
1404 1458 content[msg_id] = res
1405 1459
1406 1460 if len(theids) == 1 and failures:
1407 1461 raise failures[0]
1408 1462
1409 1463 error.collect_exceptions(failures, "result_status")
1410 1464 return content
1411 1465
1412 1466 @spin_first
1413 1467 def queue_status(self, targets='all', verbose=False):
1414 1468 """Fetch the status of engine queues.
1415 1469
1416 1470 Parameters
1417 1471 ----------
1418 1472
1419 1473 targets : int/str/list of ints/strs
1420 1474 the engines whose states are to be queried.
1421 1475 default : all
1422 1476 verbose : bool
1423 1477 Whether to return lengths only, or lists of ids for each element
1424 1478 """
1425 1479 if targets == 'all':
1426 1480 # allow 'all' to be evaluated on the engine
1427 1481 engine_ids = None
1428 1482 else:
1429 1483 engine_ids = self._build_targets(targets)[1]
1430 1484 content = dict(targets=engine_ids, verbose=verbose)
1431 1485 self.session.send(self._query_socket, "queue_request", content=content)
1432 1486 idents,msg = self.session.recv(self._query_socket, 0)
1433 1487 if self.debug:
1434 1488 pprint(msg)
1435 1489 content = msg['content']
1436 1490 status = content.pop('status')
1437 1491 if status != 'ok':
1438 1492 raise self._unwrap_exception(content)
1439 1493 content = rekey(content)
1440 1494 if isinstance(targets, int):
1441 1495 return content[targets]
1442 1496 else:
1443 1497 return content
1444 1498
1445 1499 @spin_first
1446 1500 def purge_results(self, jobs=[], targets=[]):
1447 1501 """Tell the Hub to forget results.
1448 1502
1449 1503 Individual results can be purged by msg_id, or the entire
1450 1504 history of specific targets can be purged.
1451 1505
1452 1506 Use `purge_results('all')` to scrub everything from the Hub's db.
1453 1507
1454 1508 Parameters
1455 1509 ----------
1456 1510
1457 1511 jobs : str or list of str or AsyncResult objects
1458 1512 the msg_ids whose results should be forgotten.
1459 1513 targets : int/str/list of ints/strs
1460 1514 The targets, by int_id, whose entire history is to be purged.
1461 1515
1462 1516 default : None
1463 1517 """
1464 1518 if not targets and not jobs:
1465 1519 raise ValueError("Must specify at least one of `targets` and `jobs`")
1466 1520 if targets:
1467 1521 targets = self._build_targets(targets)[1]
1468 1522
1469 1523 # construct msg_ids from jobs
1470 1524 if jobs == 'all':
1471 1525 msg_ids = jobs
1472 1526 else:
1473 1527 msg_ids = []
1474 1528 if isinstance(jobs, (basestring,AsyncResult)):
1475 1529 jobs = [jobs]
1476 1530 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1477 1531 if bad_ids:
1478 1532 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1479 1533 for j in jobs:
1480 1534 if isinstance(j, AsyncResult):
1481 1535 msg_ids.extend(j.msg_ids)
1482 1536 else:
1483 1537 msg_ids.append(j)
1484 1538
1485 1539 content = dict(engine_ids=targets, msg_ids=msg_ids)
1486 1540 self.session.send(self._query_socket, "purge_request", content=content)
1487 1541 idents, msg = self.session.recv(self._query_socket, 0)
1488 1542 if self.debug:
1489 1543 pprint(msg)
1490 1544 content = msg['content']
1491 1545 if content['status'] != 'ok':
1492 1546 raise self._unwrap_exception(content)
1493 1547
1494 1548 @spin_first
1495 1549 def hub_history(self):
1496 1550 """Get the Hub's history
1497 1551
1498 1552 Just like the Client, the Hub has a history, which is a list of msg_ids.
1499 1553 This will contain the history of all clients, and, depending on configuration,
1500 1554 may contain history across multiple cluster sessions.
1501 1555
1502 1556 Any msg_id returned here is a valid argument to `get_result`.
1503 1557
1504 1558 Returns
1505 1559 -------
1506 1560
1507 1561 msg_ids : list of strs
1508 1562 list of all msg_ids, ordered by task submission time.
1509 1563 """
1510 1564
1511 1565 self.session.send(self._query_socket, "history_request", content={})
1512 1566 idents, msg = self.session.recv(self._query_socket, 0)
1513 1567
1514 1568 if self.debug:
1515 1569 pprint(msg)
1516 1570 content = msg['content']
1517 1571 if content['status'] != 'ok':
1518 1572 raise self._unwrap_exception(content)
1519 1573 else:
1520 1574 return content['history']
1521 1575
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.
        """
        if isinstance(keys, basestring):
            # allow a single key to be passed as a bare string
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # buffers for all records arrive flattened in msg['buffers'];
        # buffer_lens / result_buffer_lens give the per-record counts
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers: slice each record's share off the front of the flat list
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records
1565 1619
# public API of this module
__all__ = [ 'Client' ]
General Comments 0
You need to be logged in to leave comments. Login now