##// END OF EJS Templates
fix typo
Robert McGibbon -
Show More
@@ -1,1729 +1,1729 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52 from IPython.zmq import serialize
53 53
54 54 from .asyncresult import AsyncResult, AsyncHubResult
55 55 from .view import DirectView, LoadBalancedView
56 56
57 57 if sys.version_info[0] >= 3:
58 58 # xrange is used in a couple 'isinstance' tests in py2
59 59 # should be just 'range' in 3k
60 60 xrange = range
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Decorators for Client methods
64 64 #--------------------------------------------------------------------------
65 65
66 66 @decorator
67 67 def spin_first(f, self, *args, **kwargs):
68 68 """Call spin() to sync state prior to calling the method."""
69 69 self.spin()
70 70 return f(self, *args, **kwargs)
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Classes
75 75 #--------------------------------------------------------------------------
76 76
77 77
78 78 class ExecuteReply(object):
79 79 """wrapper for finished Execute results"""
80 80 def __init__(self, msg_id, content, metadata):
81 81 self.msg_id = msg_id
82 82 self._content = content
83 83 self.execution_count = content['execution_count']
84 84 self.metadata = metadata
85 85
86 86 def __getitem__(self, key):
87 87 return self.metadata[key]
88 88
89 89 def __getattr__(self, key):
90 90 if key not in self.metadata:
91 91 raise AttributeError(key)
92 92 return self.metadata[key]
93 93
94 94 def __repr__(self):
95 95 pyout = self.metadata['pyout'] or {'data':{}}
96 96 text_out = pyout['data'].get('text/plain', '')
97 97 if len(text_out) > 32:
98 98 text_out = text_out[:29] + '...'
99 99
100 100 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101 101
102 102 def _repr_pretty_(self, p, cycle):
103 103 pyout = self.metadata['pyout'] or {'data':{}}
104 104 text_out = pyout['data'].get('text/plain', '')
105 105
106 106 if not text_out:
107 107 return
108 108
109 109 try:
110 110 ip = get_ipython()
111 111 except NameError:
112 112 colors = "NoColor"
113 113 else:
114 114 colors = ip.colors
115 115
116 116 if colors == "NoColor":
117 117 out = normal = ""
118 118 else:
119 119 out = TermColors.Red
120 120 normal = TermColors.Normal
121 121
122 122 if '\n' in text_out and not text_out.startswith('\n'):
123 123 # add newline for multiline reprs
124 124 text_out = '\n' + text_out
125 125
126 126 p.text(
127 127 out + u'Out[%i:%i]: ' % (
128 128 self.metadata['engine_id'], self.execution_count
129 129 ) + normal + text_out
130 130 )
131 131
132 132 def _repr_html_(self):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 return pyout['data'].get("text/html")
135 135
136 136 def _repr_latex_(self):
137 137 pyout = self.metadata['pyout'] or {'data':{}}
138 138 return pyout['data'].get("text/latex")
139 139
140 140 def _repr_json_(self):
141 141 pyout = self.metadata['pyout'] or {'data':{}}
142 142 return pyout['data'].get("application/json")
143 143
144 144 def _repr_javascript_(self):
145 145 pyout = self.metadata['pyout'] or {'data':{}}
146 146 return pyout['data'].get("application/javascript")
147 147
148 148 def _repr_png_(self):
149 149 pyout = self.metadata['pyout'] or {'data':{}}
150 150 return pyout['data'].get("image/png")
151 151
152 152 def _repr_jpeg_(self):
153 153 pyout = self.metadata['pyout'] or {'data':{}}
154 154 return pyout['data'].get("image/jpeg")
155 155
156 156 def _repr_svg_(self):
157 157 pyout = self.metadata['pyout'] or {'data':{}}
158 158 return pyout['data'].get("image/svg+xml")
159 159
160 160
161 161 class Metadata(dict):
162 162 """Subclass of dict for initializing metadata values.
163 163
164 164 Attribute access works on keys.
165 165
166 166 These objects have a strict set of keys - errors will raise if you try
167 167 to add new keys.
168 168 """
169 169 def __init__(self, *args, **kwargs):
170 170 dict.__init__(self)
171 171 md = {'msg_id' : None,
172 172 'submitted' : None,
173 173 'started' : None,
174 174 'completed' : None,
175 175 'received' : None,
176 176 'engine_uuid' : None,
177 177 'engine_id' : None,
178 178 'follow' : None,
179 179 'after' : None,
180 180 'status' : None,
181 181
182 182 'pyin' : None,
183 183 'pyout' : None,
184 184 'pyerr' : None,
185 185 'stdout' : '',
186 186 'stderr' : '',
187 187 'outputs' : [],
188 188 'data': {},
189 189 'outputs_ready' : False,
190 190 }
191 191 self.update(md)
192 192 self.update(dict(*args, **kwargs))
193 193
194 194 def __getattr__(self, key):
195 195 """getattr aliased to getitem"""
196 196 if key in self.iterkeys():
197 197 return self[key]
198 198 else:
199 199 raise AttributeError(key)
200 200
201 201 def __setattr__(self, key, value):
202 202 """setattr aliased to setitem, with strict"""
203 203 if key in self.iterkeys():
204 204 self[key] = value
205 205 else:
206 206 raise AttributeError(key)
207 207
208 208 def __setitem__(self, key, value):
209 209 """strict static key enforcement"""
210 210 if key in self.iterkeys():
211 211 dict.__setitem__(self, key, value)
212 212 else:
213 213 raise KeyError(key)
214 214
215 215
216 216 class Client(HasTraits):
217 217 """A semi-synchronous client to the IPython ZMQ cluster
218 218
219 219 Parameters
220 220 ----------
221 221
222 222 url_file : str/unicode; path to ipcontroller-client.json
223 223 This JSON file should contain all the information needed to connect to a cluster,
224 224 and is likely the only argument needed.
225 225 Connection information for the Hub's registration. If a json connector
226 226 file is given, then likely no further configuration is necessary.
227 227 [Default: use profile]
228 228 profile : bytes
229 229 The name of the Cluster profile to be used to find connector information.
230 230 If run from an IPython application, the default profile will be the same
231 231 as the running application, otherwise it will be 'default'.
232 232 cluster_id : str
233 233 String id to added to runtime files, to prevent name collisions when using
234 234 multiple clusters with a single profile simultaneously.
235 235 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 236 Since this is text inserted into filenames, typical recommendations apply:
237 237 Simple character strings are ideal, and spaces are not recommended (but
238 238 should generally work)
239 239 context : zmq.Context
240 240 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 241 debug : bool
242 242 flag for lots of message printing for debug purposes
243 243 timeout : int/float
244 244 time (in seconds) to wait for connection replies from the Hub
245 245 [Default: 10]
246 246
247 247 #-------------- session related args ----------------
248 248
249 249 config : Config object
250 250 If specified, this will be relayed to the Session for configuration
251 251 username : str
252 252 set username for the session object
253 253
254 254 #-------------- ssh related args ----------------
255 255 # These are args for configuring the ssh tunnel to be used
256 256 # credentials are used to forward connections over ssh to the Controller
257 257 # Note that the ip given in `addr` needs to be relative to sshserver
258 258 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 259 # and set sshserver as the same machine the Controller is on. However,
260 260 # the only requirement is that sshserver is able to see the Controller
261 261 # (i.e. is within the same trusted network).
262 262
263 263 sshserver : str
264 264 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 265 If keyfile or password is specified, and this is not, it will default to
266 266 the ip given in addr.
267 267 sshkey : str; path to ssh private key file
268 268 This specifies a key to be used in ssh login, default None.
269 269 Regular default ssh keys will be used without specifying this argument.
270 270 password : str
271 271 Your ssh password to sshserver. Note that if this is left None,
272 272 you will be prompted for it if passwordless key based login is unavailable.
273 273 paramiko : bool
274 274 flag for whether to use paramiko instead of shell ssh for tunneling.
275 275 [default: True on win32, False else]
276 276
277 277
278 278 Attributes
279 279 ----------
280 280
281 281 ids : list of int engine IDs
282 282 requesting the ids attribute always synchronizes
283 283 the registration state. To request ids without synchronization,
284 284 use semi-private _ids attributes.
285 285
286 286 history : list of msg_ids
287 287 a list of msg_ids, keeping track of all the execution
288 288 messages you have submitted in order.
289 289
290 290 outstanding : set of msg_ids
291 291 a set of msg_ids that have been submitted, but whose
292 292 results have not yet been received.
293 293
294 294 results : dict
295 295 a dict of all our results, keyed by msg_id
296 296
297 297 block : bool
298 298 determines default behavior when block not specified
299 299 in execution methods
300 300
301 301 Methods
302 302 -------
303 303
304 304 spin
305 305 flushes incoming results and registration state changes
306 306 control methods spin, and requesting `ids` also ensures up to date
307 307
308 308 wait
309 309 wait on one or more msg_ids
310 310
311 311 execution methods
312 312 apply
313 313 legacy: execute, run
314 314
315 315 data movement
316 316 push, pull, scatter, gather
317 317
318 318 query methods
319 319 queue_status, get_result, purge, result_status
320 320
321 321 control methods
322 322 abort, shutdown
323 323
324 324 """
325 325
326 326
327 327 block = Bool(False)
328 328 outstanding = Set()
329 329 results = Instance('collections.defaultdict', (dict,))
330 330 metadata = Instance('collections.defaultdict', (Metadata,))
331 331 history = List()
332 332 debug = Bool(False)
333 333 _spin_thread = Any()
334 334 _stop_spinning = Any()
335 335
336 336 profile=Unicode()
337 337 def _profile_default(self):
338 338 if BaseIPythonApplication.initialized():
339 339 # an IPython app *might* be running, try to get its profile
340 340 try:
341 341 return BaseIPythonApplication.instance().profile
342 342 except (AttributeError, MultipleInstanceError):
343 343 # could be a *different* subclass of config.Application,
344 344 # which would raise one of these two errors.
345 345 return u'default'
346 346 else:
347 347 return u'default'
348 348
349 349
350 350 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 351 _ids = List()
352 352 _connected=Bool(False)
353 353 _ssh=Bool(False)
354 354 _context = Instance('zmq.Context')
355 355 _config = Dict()
356 356 _engines=Instance(util.ReverseDict, (), {})
357 357 # _hub_socket=Instance('zmq.Socket')
358 358 _query_socket=Instance('zmq.Socket')
359 359 _control_socket=Instance('zmq.Socket')
360 360 _iopub_socket=Instance('zmq.Socket')
361 361 _notification_socket=Instance('zmq.Socket')
362 362 _mux_socket=Instance('zmq.Socket')
363 363 _task_socket=Instance('zmq.Socket')
364 364 _task_scheme=Unicode()
365 365 _closed = False
366 366 _ignored_control_replies=Integer(0)
367 367 _ignored_hub_replies=Integer(0)
368 368
369 369 def __new__(self, *args, **kw):
370 370 # don't raise on positional args
371 371 return HasTraits.__new__(self, **kw)
372 372
373 373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 374 context=None, debug=False,
375 375 sshserver=None, sshkey=None, password=None, paramiko=None,
376 376 timeout=10, cluster_id=None, **extra_args
377 377 ):
378 378 if profile:
379 379 super(Client, self).__init__(debug=debug, profile=profile)
380 380 else:
381 381 super(Client, self).__init__(debug=debug)
382 382 if context is None:
383 383 context = zmq.Context.instance()
384 384 self._context = context
385 385 self._stop_spinning = Event()
386 386
387 387 if 'url_or_file' in extra_args:
388 388 url_file = extra_args['url_or_file']
389 389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390 390
391 391 if url_file and util.is_url(url_file):
392 392 raise ValueError("single urls cannot be specified, url-files must be used.")
393 393
394 394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395 395
396 396 if self._cd is not None:
397 397 if url_file is None:
398 398 if not cluster_id:
399 399 client_json = 'ipcontroller-client.json'
400 400 else:
401 401 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 url_or_file = pjoin(self._cd.security_dir, client_json)
402 url_file = pjoin(self._cd.security_dir, client_json)
403 403 if url_file is None:
404 404 raise ValueError(
405 405 "I can't find enough information to connect to a hub!"
406 406 " Please specify at least one of url_file or profile."
407 407 )
408 408
409 409 with open(url_file) as f:
410 410 cfg = json.load(f)
411 411
412 412 self._task_scheme = cfg['task_scheme']
413 413
414 414 # sync defaults from args, json:
415 415 if sshserver:
416 416 cfg['ssh'] = sshserver
417 417
418 418 location = cfg.setdefault('location', None)
419 419
420 420 proto,addr = cfg['interface'].split('://')
421 421 addr = util.disambiguate_ip_address(addr)
422 422 cfg['interface'] = "%s://%s" % (proto, addr)
423 423
424 424 # turn interface,port into full urls:
425 425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427 427
428 428 url = cfg['registration']
429 429
430 430 if location is not None and addr == '127.0.0.1':
431 431 # location specified, and connection is expected to be local
432 432 if location not in LOCAL_IPS and not sshserver:
433 433 # load ssh from JSON *only* if the controller is not on
434 434 # this machine
435 435 sshserver=cfg['ssh']
436 436 if location not in LOCAL_IPS and not sshserver:
437 437 # warn if no ssh specified, but SSH is probably needed
438 438 # This is only a warning, because the most likely cause
439 439 # is a local Controller on a laptop whose IP is dynamic
440 440 warnings.warn("""
441 441 Controller appears to be listening on localhost, but not on this machine.
442 442 If this is true, you should specify Client(...,sshserver='you@%s')
443 443 or instruct your controller to listen on an external IP."""%location,
444 444 RuntimeWarning)
445 445 elif not sshserver:
446 446 # otherwise sync with cfg
447 447 sshserver = cfg['ssh']
448 448
449 449 self._config = cfg
450 450
451 451 self._ssh = bool(sshserver or sshkey or password)
452 452 if self._ssh and sshserver is None:
453 453 # default to ssh via localhost
454 454 sshserver = addr
455 455 if self._ssh and password is None:
456 456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 457 password=False
458 458 else:
459 459 password = getpass("SSH Password for %s: "%sshserver)
460 460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461 461
462 462 # configure and construct the session
463 463 extra_args['packer'] = cfg['pack']
464 464 extra_args['unpacker'] = cfg['unpack']
465 465 extra_args['key'] = cast_bytes(cfg['exec_key'])
466 466
467 467 self.session = Session(**extra_args)
468 468
469 469 self._query_socket = self._context.socket(zmq.DEALER)
470 470
471 471 if self._ssh:
472 472 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
473 473 else:
474 474 self._query_socket.connect(cfg['registration'])
475 475
476 476 self.session.debug = self.debug
477 477
478 478 self._notification_handlers = {'registration_notification' : self._register_engine,
479 479 'unregistration_notification' : self._unregister_engine,
480 480 'shutdown_notification' : lambda msg: self.close(),
481 481 }
482 482 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
483 483 'apply_reply' : self._handle_apply_reply}
484 484 self._connect(sshserver, ssh_kwargs, timeout)
485 485
486 486 # last step: setup magics, if we are in IPython:
487 487
488 488 try:
489 489 ip = get_ipython()
490 490 except NameError:
491 491 return
492 492 else:
493 493 if 'px' not in ip.magics_manager.magics:
494 494 # in IPython but we are the first Client.
495 495 # activate a default view for parallel magics.
496 496 self.activate()
497 497
498 498 def __del__(self):
499 499 """cleanup sockets, but _not_ context."""
500 500 self.close()
501 501
502 502 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
503 503 if ipython_dir is None:
504 504 ipython_dir = get_ipython_dir()
505 505 if profile_dir is not None:
506 506 try:
507 507 self._cd = ProfileDir.find_profile_dir(profile_dir)
508 508 return
509 509 except ProfileDirError:
510 510 pass
511 511 elif profile is not None:
512 512 try:
513 513 self._cd = ProfileDir.find_profile_dir_by_name(
514 514 ipython_dir, profile)
515 515 return
516 516 except ProfileDirError:
517 517 pass
518 518 self._cd = None
519 519
520 520 def _update_engines(self, engines):
521 521 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
522 522 for k,v in engines.iteritems():
523 523 eid = int(k)
524 524 if eid not in self._engines:
525 525 self._ids.append(eid)
526 526 self._engines[eid] = v
527 527 self._ids = sorted(self._ids)
528 528 if sorted(self._engines.keys()) != range(len(self._engines)) and \
529 529 self._task_scheme == 'pure' and self._task_socket:
530 530 self._stop_scheduling_tasks()
531 531
532 532 def _stop_scheduling_tasks(self):
533 533 """Stop scheduling tasks because an engine has been unregistered
534 534 from a pure ZMQ scheduler.
535 535 """
536 536 self._task_socket.close()
537 537 self._task_socket = None
538 538 msg = "An engine has been unregistered, and we are using pure " +\
539 539 "ZMQ task scheduling. Task farming will be disabled."
540 540 if self.outstanding:
541 541 msg += " If you were running tasks when this happened, " +\
542 542 "some `outstanding` msg_ids may never resolve."
543 543 warnings.warn(msg, RuntimeWarning)
544 544
545 545 def _build_targets(self, targets):
546 546 """Turn valid target IDs or 'all' into two lists:
547 547 (int_ids, uuids).
548 548 """
549 549 if not self._ids:
550 550 # flush notification socket if no engines yet, just in case
551 551 if not self.ids:
552 552 raise error.NoEnginesRegistered("Can't build targets without any engines")
553 553
554 554 if targets is None:
555 555 targets = self._ids
556 556 elif isinstance(targets, basestring):
557 557 if targets.lower() == 'all':
558 558 targets = self._ids
559 559 else:
560 560 raise TypeError("%r not valid str target, must be 'all'"%(targets))
561 561 elif isinstance(targets, int):
562 562 if targets < 0:
563 563 targets = self.ids[targets]
564 564 if targets not in self._ids:
565 565 raise IndexError("No such engine: %i"%targets)
566 566 targets = [targets]
567 567
568 568 if isinstance(targets, slice):
569 569 indices = range(len(self._ids))[targets]
570 570 ids = self.ids
571 571 targets = [ ids[i] for i in indices ]
572 572
573 573 if not isinstance(targets, (tuple, list, xrange)):
574 574 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
575 575
576 576 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
577 577
578 578 def _connect(self, sshserver, ssh_kwargs, timeout):
579 579 """setup all our socket connections to the cluster. This is called from
580 580 __init__."""
581 581
582 582 # Maybe allow reconnecting?
583 583 if self._connected:
584 584 return
585 585 self._connected=True
586 586
587 587 def connect_socket(s, url):
588 588 # url = util.disambiguate_url(url, self._config['location'])
589 589 if self._ssh:
590 590 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
591 591 else:
592 592 return s.connect(url)
593 593
594 594 self.session.send(self._query_socket, 'connection_request')
595 595 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
596 596 poller = zmq.Poller()
597 597 poller.register(self._query_socket, zmq.POLLIN)
598 598 # poll expects milliseconds, timeout is seconds
599 599 evts = poller.poll(timeout*1000)
600 600 if not evts:
601 601 raise error.TimeoutError("Hub connection request timed out")
602 602 idents,msg = self.session.recv(self._query_socket,mode=0)
603 603 if self.debug:
604 604 pprint(msg)
605 605 content = msg['content']
606 606 # self._config['registration'] = dict(content)
607 607 cfg = self._config
608 608 if content['status'] == 'ok':
609 609 self._mux_socket = self._context.socket(zmq.DEALER)
610 610 connect_socket(self._mux_socket, cfg['mux'])
611 611
612 612 self._task_socket = self._context.socket(zmq.DEALER)
613 613 connect_socket(self._task_socket, cfg['task'])
614 614
615 615 self._notification_socket = self._context.socket(zmq.SUB)
616 616 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
617 617 connect_socket(self._notification_socket, cfg['notification'])
618 618
619 619 self._control_socket = self._context.socket(zmq.DEALER)
620 620 connect_socket(self._control_socket, cfg['control'])
621 621
622 622 self._iopub_socket = self._context.socket(zmq.SUB)
623 623 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
624 624 connect_socket(self._iopub_socket, cfg['iopub'])
625 625
626 626 self._update_engines(dict(content['engines']))
627 627 else:
628 628 self._connected = False
629 629 raise Exception("Failed to connect!")
630 630
631 631 #--------------------------------------------------------------------------
632 632 # handlers and callbacks for incoming messages
633 633 #--------------------------------------------------------------------------
634 634
635 635 def _unwrap_exception(self, content):
636 636 """unwrap exception, and remap engine_id to int."""
637 637 e = error.unwrap_exception(content)
638 638 # print e.traceback
639 639 if e.engine_info:
640 640 e_uuid = e.engine_info['engine_uuid']
641 641 eid = self._engines[e_uuid]
642 642 e.engine_info['engine_id'] = eid
643 643 return e
644 644
645 645 def _extract_metadata(self, msg):
646 646 header = msg['header']
647 647 parent = msg['parent_header']
648 648 msg_meta = msg['metadata']
649 649 content = msg['content']
650 650 md = {'msg_id' : parent['msg_id'],
651 651 'received' : datetime.now(),
652 652 'engine_uuid' : msg_meta.get('engine', None),
653 653 'follow' : msg_meta.get('follow', []),
654 654 'after' : msg_meta.get('after', []),
655 655 'status' : content['status'],
656 656 }
657 657
658 658 if md['engine_uuid'] is not None:
659 659 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
660 660
661 661 if 'date' in parent:
662 662 md['submitted'] = parent['date']
663 663 if 'started' in msg_meta:
664 664 md['started'] = msg_meta['started']
665 665 if 'date' in header:
666 666 md['completed'] = header['date']
667 667 return md
668 668
669 669 def _register_engine(self, msg):
670 670 """Register a new engine, and update our connection info."""
671 671 content = msg['content']
672 672 eid = content['id']
673 673 d = {eid : content['uuid']}
674 674 self._update_engines(d)
675 675
676 676 def _unregister_engine(self, msg):
677 677 """Unregister an engine that has died."""
678 678 content = msg['content']
679 679 eid = int(content['id'])
680 680 if eid in self._ids:
681 681 self._ids.remove(eid)
682 682 uuid = self._engines.pop(eid)
683 683
684 684 self._handle_stranded_msgs(eid, uuid)
685 685
686 686 if self._task_socket and self._task_scheme == 'pure':
687 687 self._stop_scheduling_tasks()
688 688
689 689 def _handle_stranded_msgs(self, eid, uuid):
690 690 """Handle messages known to be on an engine when the engine unregisters.
691 691
692 692 It is possible that this will fire prematurely - that is, an engine will
693 693 go down after completing a result, and the client will be notified
694 694 of the unregistration and later receive the successful result.
695 695 """
696 696
697 697 outstanding = self._outstanding_dict[uuid]
698 698
699 699 for msg_id in list(outstanding):
700 700 if msg_id in self.results:
701 701 # we already
702 702 continue
703 703 try:
704 704 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
705 705 except:
706 706 content = error.wrap_exception()
707 707 # build a fake message:
708 708 parent = {}
709 709 header = {}
710 710 parent['msg_id'] = msg_id
711 711 header['engine'] = uuid
712 712 header['date'] = datetime.now()
713 713 msg = dict(parent_header=parent, header=header, content=content)
714 714 self._handle_apply_reply(msg)
715 715
716 716 def _handle_execute_reply(self, msg):
717 717 """Save the reply to an execute_request into our results.
718 718
719 719 execute messages are never actually used. apply is used instead.
720 720 """
721 721
722 722 parent = msg['parent_header']
723 723 msg_id = parent['msg_id']
724 724 if msg_id not in self.outstanding:
725 725 if msg_id in self.history:
726 726 print ("got stale result: %s"%msg_id)
727 727 else:
728 728 print ("got unknown result: %s"%msg_id)
729 729 else:
730 730 self.outstanding.remove(msg_id)
731 731
732 732 content = msg['content']
733 733 header = msg['header']
734 734
735 735 # construct metadata:
736 736 md = self.metadata[msg_id]
737 737 md.update(self._extract_metadata(msg))
738 738 # is this redundant?
739 739 self.metadata[msg_id] = md
740 740
741 741 e_outstanding = self._outstanding_dict[md['engine_uuid']]
742 742 if msg_id in e_outstanding:
743 743 e_outstanding.remove(msg_id)
744 744
745 745 # construct result:
746 746 if content['status'] == 'ok':
747 747 self.results[msg_id] = ExecuteReply(msg_id, content, md)
748 748 elif content['status'] == 'aborted':
749 749 self.results[msg_id] = error.TaskAborted(msg_id)
750 750 elif content['status'] == 'resubmitted':
751 751 # TODO: handle resubmission
752 752 pass
753 753 else:
754 754 self.results[msg_id] = self._unwrap_exception(content)
755 755
756 756 def _handle_apply_reply(self, msg):
757 757 """Save the reply to an apply_request into our results."""
758 758 parent = msg['parent_header']
759 759 msg_id = parent['msg_id']
760 760 if msg_id not in self.outstanding:
761 761 if msg_id in self.history:
762 762 print ("got stale result: %s"%msg_id)
763 763 print self.results[msg_id]
764 764 print msg
765 765 else:
766 766 print ("got unknown result: %s"%msg_id)
767 767 else:
768 768 self.outstanding.remove(msg_id)
769 769 content = msg['content']
770 770 header = msg['header']
771 771
772 772 # construct metadata:
773 773 md = self.metadata[msg_id]
774 774 md.update(self._extract_metadata(msg))
775 775 # is this redundant?
776 776 self.metadata[msg_id] = md
777 777
778 778 e_outstanding = self._outstanding_dict[md['engine_uuid']]
779 779 if msg_id in e_outstanding:
780 780 e_outstanding.remove(msg_id)
781 781
782 782 # construct result:
783 783 if content['status'] == 'ok':
784 784 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
785 785 elif content['status'] == 'aborted':
786 786 self.results[msg_id] = error.TaskAborted(msg_id)
787 787 elif content['status'] == 'resubmitted':
788 788 # TODO: handle resubmission
789 789 pass
790 790 else:
791 791 self.results[msg_id] = self._unwrap_exception(content)
792 792
793 793 def _flush_notifications(self):
794 794 """Flush notifications of engine registrations waiting
795 795 in ZMQ queue."""
796 796 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
797 797 while msg is not None:
798 798 if self.debug:
799 799 pprint(msg)
800 800 msg_type = msg['header']['msg_type']
801 801 handler = self._notification_handlers.get(msg_type, None)
802 802 if handler is None:
803 803 raise Exception("Unhandled message type: %s"%msg.msg_type)
804 804 else:
805 805 handler(msg)
806 806 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
807 807
808 808 def _flush_results(self, sock):
809 809 """Flush task or queue results waiting in ZMQ queue."""
810 810 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
811 811 while msg is not None:
812 812 if self.debug:
813 813 pprint(msg)
814 814 msg_type = msg['header']['msg_type']
815 815 handler = self._queue_handlers.get(msg_type, None)
816 816 if handler is None:
817 817 raise Exception("Unhandled message type: %s"%msg.msg_type)
818 818 else:
819 819 handler(msg)
820 820 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 821
822 822 def _flush_control(self, sock):
823 823 """Flush replies from the control channel waiting
824 824 in the ZMQ queue.
825 825
826 826 Currently: ignore them."""
827 827 if self._ignored_control_replies <= 0:
828 828 return
829 829 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
830 830 while msg is not None:
831 831 self._ignored_control_replies -= 1
832 832 if self.debug:
833 833 pprint(msg)
834 834 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
835 835
836 836 def _flush_ignored_control(self):
837 837 """flush ignored control replies"""
838 838 while self._ignored_control_replies > 0:
839 839 self.session.recv(self._control_socket)
840 840 self._ignored_control_replies -= 1
841 841
842 842 def _flush_ignored_hub_replies(self):
843 843 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
844 844 while msg is not None:
845 845 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
846 846
847 847 def _flush_iopub(self, sock):
848 848 """Flush replies from the iopub channel waiting
849 849 in the ZMQ queue.
850 850 """
851 851 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
852 852 while msg is not None:
853 853 if self.debug:
854 854 pprint(msg)
855 855 parent = msg['parent_header']
856 856 # ignore IOPub messages with no parent.
857 857 # Caused by print statements or warnings from before the first execution.
858 858 if not parent:
859 859 continue
860 860 msg_id = parent['msg_id']
861 861 content = msg['content']
862 862 header = msg['header']
863 863 msg_type = msg['header']['msg_type']
864 864
865 865 # init metadata:
866 866 md = self.metadata[msg_id]
867 867
868 868 if msg_type == 'stream':
869 869 name = content['name']
870 870 s = md[name] or ''
871 871 md[name] = s + content['data']
872 872 elif msg_type == 'pyerr':
873 873 md.update({'pyerr' : self._unwrap_exception(content)})
874 874 elif msg_type == 'pyin':
875 875 md.update({'pyin' : content['code']})
876 876 elif msg_type == 'display_data':
877 877 md['outputs'].append(content)
878 878 elif msg_type == 'pyout':
879 879 md['pyout'] = content
880 880 elif msg_type == 'data_message':
881 881 data, remainder = serialize.unserialize_object(msg['buffers'])
882 882 md['data'].update(data)
883 883 elif msg_type == 'status':
884 884 # idle message comes after all outputs
885 885 if content['execution_state'] == 'idle':
886 886 md['outputs_ready'] = True
887 887 else:
888 888 # unhandled msg_type (status, etc.)
889 889 pass
890 890
891 891 # reduntant?
892 892 self.metadata[msg_id] = md
893 893
894 894 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
895 895
896 896 #--------------------------------------------------------------------------
897 897 # len, getitem
898 898 #--------------------------------------------------------------------------
899 899
900 900 def __len__(self):
901 901 """len(client) returns # of engines."""
902 902 return len(self.ids)
903 903
904 904 def __getitem__(self, key):
905 905 """index access returns DirectView multiplexer objects
906 906
907 907 Must be int, slice, or list/tuple/xrange of ints"""
908 908 if not isinstance(key, (int, slice, tuple, list, xrange)):
909 909 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
910 910 else:
911 911 return self.direct_view(key)
912 912
913 913 #--------------------------------------------------------------------------
914 914 # Begin public methods
915 915 #--------------------------------------------------------------------------
916 916
917 917 @property
918 918 def ids(self):
919 919 """Always up-to-date ids property."""
920 920 self._flush_notifications()
921 921 # always copy:
922 922 return list(self._ids)
923 923
924 924 def activate(self, targets='all', suffix=''):
925 925 """Create a DirectView and register it with IPython magics
926 926
927 927 Defines the magics `%px, %autopx, %pxresult, %%px`
928 928
929 929 Parameters
930 930 ----------
931 931
932 932 targets: int, list of ints, or 'all'
933 933 The engines on which the view's magics will run
934 934 suffix: str [default: '']
935 935 The suffix, if any, for the magics. This allows you to have
936 936 multiple views associated with parallel magics at the same time.
937 937
938 938 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
939 939 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
940 940 on engine 0.
941 941 """
942 942 view = self.direct_view(targets)
943 943 view.block = True
944 944 view.activate(suffix)
945 945 return view
946 946
947 947 def close(self):
948 948 if self._closed:
949 949 return
950 950 self.stop_spin_thread()
951 951 snames = filter(lambda n: n.endswith('socket'), dir(self))
952 952 for socket in map(lambda name: getattr(self, name), snames):
953 953 if isinstance(socket, zmq.Socket) and not socket.closed:
954 954 socket.close()
955 955 self._closed = True
956 956
957 957 def _spin_every(self, interval=1):
958 958 """target func for use in spin_thread"""
959 959 while True:
960 960 if self._stop_spinning.is_set():
961 961 return
962 962 time.sleep(interval)
963 963 self.spin()
964 964
965 965 def spin_thread(self, interval=1):
966 966 """call Client.spin() in a background thread on some regular interval
967 967
968 968 This helps ensure that messages don't pile up too much in the zmq queue
969 969 while you are working on other things, or just leaving an idle terminal.
970 970
971 971 It also helps limit potential padding of the `received` timestamp
972 972 on AsyncResult objects, used for timings.
973 973
974 974 Parameters
975 975 ----------
976 976
977 977 interval : float, optional
978 978 The interval on which to spin the client in the background thread
979 979 (simply passed to time.sleep).
980 980
981 981 Notes
982 982 -----
983 983
984 984 For precision timing, you may want to use this method to put a bound
985 985 on the jitter (in seconds) in `received` timestamps used
986 986 in AsyncResult.wall_time.
987 987
988 988 """
989 989 if self._spin_thread is not None:
990 990 self.stop_spin_thread()
991 991 self._stop_spinning.clear()
992 992 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
993 993 self._spin_thread.daemon = True
994 994 self._spin_thread.start()
995 995
996 996 def stop_spin_thread(self):
997 997 """stop background spin_thread, if any"""
998 998 if self._spin_thread is not None:
999 999 self._stop_spinning.set()
1000 1000 self._spin_thread.join()
1001 1001 self._spin_thread = None
1002 1002
1003 1003 def spin(self):
1004 1004 """Flush any registration notifications and execution results
1005 1005 waiting in the ZMQ queue.
1006 1006 """
1007 1007 if self._notification_socket:
1008 1008 self._flush_notifications()
1009 1009 if self._iopub_socket:
1010 1010 self._flush_iopub(self._iopub_socket)
1011 1011 if self._mux_socket:
1012 1012 self._flush_results(self._mux_socket)
1013 1013 if self._task_socket:
1014 1014 self._flush_results(self._task_socket)
1015 1015 if self._control_socket:
1016 1016 self._flush_control(self._control_socket)
1017 1017 if self._query_socket:
1018 1018 self._flush_ignored_hub_replies()
1019 1019
1020 1020 def wait(self, jobs=None, timeout=-1):
1021 1021 """waits on one or more `jobs`, for up to `timeout` seconds.
1022 1022
1023 1023 Parameters
1024 1024 ----------
1025 1025
1026 1026 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1027 1027 ints are indices to self.history
1028 1028 strs are msg_ids
1029 1029 default: wait on all outstanding messages
1030 1030 timeout : float
1031 1031 a time in seconds, after which to give up.
1032 1032 default is -1, which means no timeout
1033 1033
1034 1034 Returns
1035 1035 -------
1036 1036
1037 1037 True : when all msg_ids are done
1038 1038 False : timeout reached, some msg_ids still outstanding
1039 1039 """
1040 1040 tic = time.time()
1041 1041 if jobs is None:
1042 1042 theids = self.outstanding
1043 1043 else:
1044 1044 if isinstance(jobs, (int, basestring, AsyncResult)):
1045 1045 jobs = [jobs]
1046 1046 theids = set()
1047 1047 for job in jobs:
1048 1048 if isinstance(job, int):
1049 1049 # index access
1050 1050 job = self.history[job]
1051 1051 elif isinstance(job, AsyncResult):
1052 1052 map(theids.add, job.msg_ids)
1053 1053 continue
1054 1054 theids.add(job)
1055 1055 if not theids.intersection(self.outstanding):
1056 1056 return True
1057 1057 self.spin()
1058 1058 while theids.intersection(self.outstanding):
1059 1059 if timeout >= 0 and ( time.time()-tic ) > timeout:
1060 1060 break
1061 1061 time.sleep(1e-3)
1062 1062 self.spin()
1063 1063 return len(theids.intersection(self.outstanding)) == 0
1064 1064
1065 1065 #--------------------------------------------------------------------------
1066 1066 # Control methods
1067 1067 #--------------------------------------------------------------------------
1068 1068
1069 1069 @spin_first
1070 1070 def clear(self, targets=None, block=None):
1071 1071 """Clear the namespace in target(s)."""
1072 1072 block = self.block if block is None else block
1073 1073 targets = self._build_targets(targets)[0]
1074 1074 for t in targets:
1075 1075 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1076 1076 error = False
1077 1077 if block:
1078 1078 self._flush_ignored_control()
1079 1079 for i in range(len(targets)):
1080 1080 idents,msg = self.session.recv(self._control_socket,0)
1081 1081 if self.debug:
1082 1082 pprint(msg)
1083 1083 if msg['content']['status'] != 'ok':
1084 1084 error = self._unwrap_exception(msg['content'])
1085 1085 else:
1086 1086 self._ignored_control_replies += len(targets)
1087 1087 if error:
1088 1088 raise error
1089 1089
1090 1090
1091 1091 @spin_first
1092 1092 def abort(self, jobs=None, targets=None, block=None):
1093 1093 """Abort specific jobs from the execution queues of target(s).
1094 1094
1095 1095 This is a mechanism to prevent jobs that have already been submitted
1096 1096 from executing.
1097 1097
1098 1098 Parameters
1099 1099 ----------
1100 1100
1101 1101 jobs : msg_id, list of msg_ids, or AsyncResult
1102 1102 The jobs to be aborted
1103 1103
1104 1104 If unspecified/None: abort all outstanding jobs.
1105 1105
1106 1106 """
1107 1107 block = self.block if block is None else block
1108 1108 jobs = jobs if jobs is not None else list(self.outstanding)
1109 1109 targets = self._build_targets(targets)[0]
1110 1110
1111 1111 msg_ids = []
1112 1112 if isinstance(jobs, (basestring,AsyncResult)):
1113 1113 jobs = [jobs]
1114 1114 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1115 1115 if bad_ids:
1116 1116 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1117 1117 for j in jobs:
1118 1118 if isinstance(j, AsyncResult):
1119 1119 msg_ids.extend(j.msg_ids)
1120 1120 else:
1121 1121 msg_ids.append(j)
1122 1122 content = dict(msg_ids=msg_ids)
1123 1123 for t in targets:
1124 1124 self.session.send(self._control_socket, 'abort_request',
1125 1125 content=content, ident=t)
1126 1126 error = False
1127 1127 if block:
1128 1128 self._flush_ignored_control()
1129 1129 for i in range(len(targets)):
1130 1130 idents,msg = self.session.recv(self._control_socket,0)
1131 1131 if self.debug:
1132 1132 pprint(msg)
1133 1133 if msg['content']['status'] != 'ok':
1134 1134 error = self._unwrap_exception(msg['content'])
1135 1135 else:
1136 1136 self._ignored_control_replies += len(targets)
1137 1137 if error:
1138 1138 raise error
1139 1139
1140 1140 @spin_first
1141 1141 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1142 1142 """Terminates one or more engine processes, optionally including the hub.
1143 1143
1144 1144 Parameters
1145 1145 ----------
1146 1146
1147 1147 targets: list of ints or 'all' [default: all]
1148 1148 Which engines to shutdown.
1149 1149 hub: bool [default: False]
1150 1150 Whether to include the Hub. hub=True implies targets='all'.
1151 1151 block: bool [default: self.block]
1152 1152 Whether to wait for clean shutdown replies or not.
1153 1153 restart: bool [default: False]
1154 1154 NOT IMPLEMENTED
1155 1155 whether to restart engines after shutting them down.
1156 1156 """
1157 1157
1158 1158 if restart:
1159 1159 raise NotImplementedError("Engine restart is not yet implemented")
1160 1160
1161 1161 block = self.block if block is None else block
1162 1162 if hub:
1163 1163 targets = 'all'
1164 1164 targets = self._build_targets(targets)[0]
1165 1165 for t in targets:
1166 1166 self.session.send(self._control_socket, 'shutdown_request',
1167 1167 content={'restart':restart},ident=t)
1168 1168 error = False
1169 1169 if block or hub:
1170 1170 self._flush_ignored_control()
1171 1171 for i in range(len(targets)):
1172 1172 idents,msg = self.session.recv(self._control_socket, 0)
1173 1173 if self.debug:
1174 1174 pprint(msg)
1175 1175 if msg['content']['status'] != 'ok':
1176 1176 error = self._unwrap_exception(msg['content'])
1177 1177 else:
1178 1178 self._ignored_control_replies += len(targets)
1179 1179
1180 1180 if hub:
1181 1181 time.sleep(0.25)
1182 1182 self.session.send(self._query_socket, 'shutdown_request')
1183 1183 idents,msg = self.session.recv(self._query_socket, 0)
1184 1184 if self.debug:
1185 1185 pprint(msg)
1186 1186 if msg['content']['status'] != 'ok':
1187 1187 error = self._unwrap_exception(msg['content'])
1188 1188
1189 1189 if error:
1190 1190 raise error
1191 1191
1192 1192 #--------------------------------------------------------------------------
1193 1193 # Execution related methods
1194 1194 #--------------------------------------------------------------------------
1195 1195
1196 1196 def _maybe_raise(self, result):
1197 1197 """wrapper for maybe raising an exception if apply failed."""
1198 1198 if isinstance(result, error.RemoteError):
1199 1199 raise result
1200 1200
1201 1201 return result
1202 1202
1203 1203 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1204 1204 ident=None):
1205 1205 """construct and send an apply message via a socket.
1206 1206
1207 1207 This is the principal method with which all engine execution is performed by views.
1208 1208 """
1209 1209
1210 1210 if self._closed:
1211 1211 raise RuntimeError("Client cannot be used after its sockets have been closed")
1212 1212
1213 1213 # defaults:
1214 1214 args = args if args is not None else []
1215 1215 kwargs = kwargs if kwargs is not None else {}
1216 1216 metadata = metadata if metadata is not None else {}
1217 1217
1218 1218 # validate arguments
1219 1219 if not callable(f) and not isinstance(f, Reference):
1220 1220 raise TypeError("f must be callable, not %s"%type(f))
1221 1221 if not isinstance(args, (tuple, list)):
1222 1222 raise TypeError("args must be tuple or list, not %s"%type(args))
1223 1223 if not isinstance(kwargs, dict):
1224 1224 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1225 1225 if not isinstance(metadata, dict):
1226 1226 raise TypeError("metadata must be dict, not %s"%type(metadata))
1227 1227
1228 1228 bufs = serialize.pack_apply_message(f, args, kwargs,
1229 1229 buffer_threshold=self.session.buffer_threshold,
1230 1230 item_threshold=self.session.item_threshold,
1231 1231 )
1232 1232
1233 1233 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1234 1234 metadata=metadata, track=track)
1235 1235
1236 1236 msg_id = msg['header']['msg_id']
1237 1237 self.outstanding.add(msg_id)
1238 1238 if ident:
1239 1239 # possibly routed to a specific engine
1240 1240 if isinstance(ident, list):
1241 1241 ident = ident[-1]
1242 1242 if ident in self._engines.values():
1243 1243 # save for later, in case of engine death
1244 1244 self._outstanding_dict[ident].add(msg_id)
1245 1245 self.history.append(msg_id)
1246 1246 self.metadata[msg_id]['submitted'] = datetime.now()
1247 1247
1248 1248 return msg
1249 1249
1250 1250 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1251 1251 """construct and send an execute request via a socket.
1252 1252
1253 1253 """
1254 1254
1255 1255 if self._closed:
1256 1256 raise RuntimeError("Client cannot be used after its sockets have been closed")
1257 1257
1258 1258 # defaults:
1259 1259 metadata = metadata if metadata is not None else {}
1260 1260
1261 1261 # validate arguments
1262 1262 if not isinstance(code, basestring):
1263 1263 raise TypeError("code must be text, not %s" % type(code))
1264 1264 if not isinstance(metadata, dict):
1265 1265 raise TypeError("metadata must be dict, not %s" % type(metadata))
1266 1266
1267 1267 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1268 1268
1269 1269
1270 1270 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1271 1271 metadata=metadata)
1272 1272
1273 1273 msg_id = msg['header']['msg_id']
1274 1274 self.outstanding.add(msg_id)
1275 1275 if ident:
1276 1276 # possibly routed to a specific engine
1277 1277 if isinstance(ident, list):
1278 1278 ident = ident[-1]
1279 1279 if ident in self._engines.values():
1280 1280 # save for later, in case of engine death
1281 1281 self._outstanding_dict[ident].add(msg_id)
1282 1282 self.history.append(msg_id)
1283 1283 self.metadata[msg_id]['submitted'] = datetime.now()
1284 1284
1285 1285 return msg
1286 1286
1287 1287 #--------------------------------------------------------------------------
1288 1288 # construct a View object
1289 1289 #--------------------------------------------------------------------------
1290 1290
1291 1291 def load_balanced_view(self, targets=None):
1292 1292 """construct a DirectView object.
1293 1293
1294 1294 If no arguments are specified, create a LoadBalancedView
1295 1295 using all engines.
1296 1296
1297 1297 Parameters
1298 1298 ----------
1299 1299
1300 1300 targets: list,slice,int,etc. [default: use all engines]
1301 1301 The subset of engines across which to load-balance
1302 1302 """
1303 1303 if targets == 'all':
1304 1304 targets = None
1305 1305 if targets is not None:
1306 1306 targets = self._build_targets(targets)[1]
1307 1307 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1308 1308
1309 1309 def direct_view(self, targets='all'):
1310 1310 """construct a DirectView object.
1311 1311
1312 1312 If no targets are specified, create a DirectView using all engines.
1313 1313
1314 1314 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1315 1315 evaluate the target engines at each execution, whereas rc[:] will connect to
1316 1316 all *current* engines, and that list will not change.
1317 1317
1318 1318 That is, 'all' will always use all engines, whereas rc[:] will not use
1319 1319 engines added after the DirectView is constructed.
1320 1320
1321 1321 Parameters
1322 1322 ----------
1323 1323
1324 1324 targets: list,slice,int,etc. [default: use all engines]
1325 1325 The engines to use for the View
1326 1326 """
1327 1327 single = isinstance(targets, int)
1328 1328 # allow 'all' to be lazily evaluated at each execution
1329 1329 if targets != 'all':
1330 1330 targets = self._build_targets(targets)[1]
1331 1331 if single:
1332 1332 targets = targets[0]
1333 1333 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1334 1334
1335 1335 #--------------------------------------------------------------------------
1336 1336 # Query methods
1337 1337 #--------------------------------------------------------------------------
1338 1338
1339 1339 @spin_first
1340 1340 def get_result(self, indices_or_msg_ids=None, block=None):
1341 1341 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1342 1342
1343 1343 If the client already has the results, no request to the Hub will be made.
1344 1344
1345 1345 This is a convenient way to construct AsyncResult objects, which are wrappers
1346 1346 that include metadata about execution, and allow for awaiting results that
1347 1347 were not submitted by this Client.
1348 1348
1349 1349 It can also be a convenient way to retrieve the metadata associated with
1350 1350 blocking execution, since it always retrieves
1351 1351
1352 1352 Examples
1353 1353 --------
1354 1354 ::
1355 1355
1356 1356 In [10]: r = client.apply()
1357 1357
1358 1358 Parameters
1359 1359 ----------
1360 1360
1361 1361 indices_or_msg_ids : integer history index, str msg_id, or list of either
1362 1362 The indices or msg_ids of indices to be retrieved
1363 1363
1364 1364 block : bool
1365 1365 Whether to wait for the result to be done
1366 1366
1367 1367 Returns
1368 1368 -------
1369 1369
1370 1370 AsyncResult
1371 1371 A single AsyncResult object will always be returned.
1372 1372
1373 1373 AsyncHubResult
1374 1374 A subclass of AsyncResult that retrieves results from the Hub
1375 1375
1376 1376 """
1377 1377 block = self.block if block is None else block
1378 1378 if indices_or_msg_ids is None:
1379 1379 indices_or_msg_ids = -1
1380 1380
1381 1381 if not isinstance(indices_or_msg_ids, (list,tuple)):
1382 1382 indices_or_msg_ids = [indices_or_msg_ids]
1383 1383
1384 1384 theids = []
1385 1385 for id in indices_or_msg_ids:
1386 1386 if isinstance(id, int):
1387 1387 id = self.history[id]
1388 1388 if not isinstance(id, basestring):
1389 1389 raise TypeError("indices must be str or int, not %r"%id)
1390 1390 theids.append(id)
1391 1391
1392 1392 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1393 1393 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1394 1394
1395 1395 if remote_ids:
1396 1396 ar = AsyncHubResult(self, msg_ids=theids)
1397 1397 else:
1398 1398 ar = AsyncResult(self, msg_ids=theids)
1399 1399
1400 1400 if block:
1401 1401 ar.wait()
1402 1402
1403 1403 return ar
1404 1404
1405 1405 @spin_first
1406 1406 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1407 1407 """Resubmit one or more tasks.
1408 1408
1409 1409 in-flight tasks may not be resubmitted.
1410 1410
1411 1411 Parameters
1412 1412 ----------
1413 1413
1414 1414 indices_or_msg_ids : integer history index, str msg_id, or list of either
1415 1415 The indices or msg_ids of indices to be retrieved
1416 1416
1417 1417 block : bool
1418 1418 Whether to wait for the result to be done
1419 1419
1420 1420 Returns
1421 1421 -------
1422 1422
1423 1423 AsyncHubResult
1424 1424 A subclass of AsyncResult that retrieves results from the Hub
1425 1425
1426 1426 """
1427 1427 block = self.block if block is None else block
1428 1428 if indices_or_msg_ids is None:
1429 1429 indices_or_msg_ids = -1
1430 1430
1431 1431 if not isinstance(indices_or_msg_ids, (list,tuple)):
1432 1432 indices_or_msg_ids = [indices_or_msg_ids]
1433 1433
1434 1434 theids = []
1435 1435 for id in indices_or_msg_ids:
1436 1436 if isinstance(id, int):
1437 1437 id = self.history[id]
1438 1438 if not isinstance(id, basestring):
1439 1439 raise TypeError("indices must be str or int, not %r"%id)
1440 1440 theids.append(id)
1441 1441
1442 1442 content = dict(msg_ids = theids)
1443 1443
1444 1444 self.session.send(self._query_socket, 'resubmit_request', content)
1445 1445
1446 1446 zmq.select([self._query_socket], [], [])
1447 1447 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1448 1448 if self.debug:
1449 1449 pprint(msg)
1450 1450 content = msg['content']
1451 1451 if content['status'] != 'ok':
1452 1452 raise self._unwrap_exception(content)
1453 1453 mapping = content['resubmitted']
1454 1454 new_ids = [ mapping[msg_id] for msg_id in theids ]
1455 1455
1456 1456 ar = AsyncHubResult(self, msg_ids=new_ids)
1457 1457
1458 1458 if block:
1459 1459 ar.wait()
1460 1460
1461 1461 return ar
1462 1462
1463 1463 @spin_first
1464 1464 def result_status(self, msg_ids, status_only=True):
1465 1465 """Check on the status of the result(s) of the apply request with `msg_ids`.
1466 1466
1467 1467 If status_only is False, then the actual results will be retrieved, else
1468 1468 only the status of the results will be checked.
1469 1469
1470 1470 Parameters
1471 1471 ----------
1472 1472
1473 1473 msg_ids : list of msg_ids
1474 1474 if int:
1475 1475 Passed as index to self.history for convenience.
1476 1476 status_only : bool (default: True)
1477 1477 if False:
1478 1478 Retrieve the actual results of completed tasks.
1479 1479
1480 1480 Returns
1481 1481 -------
1482 1482
1483 1483 results : dict
1484 1484 There will always be the keys 'pending' and 'completed', which will
1485 1485 be lists of msg_ids that are incomplete or complete. If `status_only`
1486 1486 is False, then completed results will be keyed by their `msg_id`.
1487 1487 """
1488 1488 if not isinstance(msg_ids, (list,tuple)):
1489 1489 msg_ids = [msg_ids]
1490 1490
1491 1491 theids = []
1492 1492 for msg_id in msg_ids:
1493 1493 if isinstance(msg_id, int):
1494 1494 msg_id = self.history[msg_id]
1495 1495 if not isinstance(msg_id, basestring):
1496 1496 raise TypeError("msg_ids must be str, not %r"%msg_id)
1497 1497 theids.append(msg_id)
1498 1498
1499 1499 completed = []
1500 1500 local_results = {}
1501 1501
1502 1502 # comment this block out to temporarily disable local shortcut:
1503 1503 for msg_id in theids:
1504 1504 if msg_id in self.results:
1505 1505 completed.append(msg_id)
1506 1506 local_results[msg_id] = self.results[msg_id]
1507 1507 theids.remove(msg_id)
1508 1508
1509 1509 if theids: # some not locally cached
1510 1510 content = dict(msg_ids=theids, status_only=status_only)
1511 1511 msg = self.session.send(self._query_socket, "result_request", content=content)
1512 1512 zmq.select([self._query_socket], [], [])
1513 1513 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1514 1514 if self.debug:
1515 1515 pprint(msg)
1516 1516 content = msg['content']
1517 1517 if content['status'] != 'ok':
1518 1518 raise self._unwrap_exception(content)
1519 1519 buffers = msg['buffers']
1520 1520 else:
1521 1521 content = dict(completed=[],pending=[])
1522 1522
1523 1523 content['completed'].extend(completed)
1524 1524
1525 1525 if status_only:
1526 1526 return content
1527 1527
1528 1528 failures = []
1529 1529 # load cached results into result:
1530 1530 content.update(local_results)
1531 1531
1532 1532 # update cache with results:
1533 1533 for msg_id in sorted(theids):
1534 1534 if msg_id in content['completed']:
1535 1535 rec = content[msg_id]
1536 1536 parent = rec['header']
1537 1537 header = rec['result_header']
1538 1538 rcontent = rec['result_content']
1539 1539 iodict = rec['io']
1540 1540 if isinstance(rcontent, str):
1541 1541 rcontent = self.session.unpack(rcontent)
1542 1542
1543 1543 md = self.metadata[msg_id]
1544 1544 md_msg = dict(
1545 1545 content=rcontent,
1546 1546 parent_header=parent,
1547 1547 header=header,
1548 1548 metadata=rec['result_metadata'],
1549 1549 )
1550 1550 md.update(self._extract_metadata(md_msg))
1551 1551 if rec.get('received'):
1552 1552 md['received'] = rec['received']
1553 1553 md.update(iodict)
1554 1554
1555 1555 if rcontent['status'] == 'ok':
1556 1556 if header['msg_type'] == 'apply_reply':
1557 1557 res,buffers = serialize.unserialize_object(buffers)
1558 1558 elif header['msg_type'] == 'execute_reply':
1559 1559 res = ExecuteReply(msg_id, rcontent, md)
1560 1560 else:
1561 1561 raise KeyError("unhandled msg type: %r" % header[msg_type])
1562 1562 else:
1563 1563 res = self._unwrap_exception(rcontent)
1564 1564 failures.append(res)
1565 1565
1566 1566 self.results[msg_id] = res
1567 1567 content[msg_id] = res
1568 1568
1569 1569 if len(theids) == 1 and failures:
1570 1570 raise failures[0]
1571 1571
1572 1572 error.collect_exceptions(failures, "result_status")
1573 1573 return content
1574 1574
1575 1575 @spin_first
1576 1576 def queue_status(self, targets='all', verbose=False):
1577 1577 """Fetch the status of engine queues.
1578 1578
1579 1579 Parameters
1580 1580 ----------
1581 1581
1582 1582 targets : int/str/list of ints/strs
1583 1583 the engines whose states are to be queried.
1584 1584 default : all
1585 1585 verbose : bool
1586 1586 Whether to return lengths only, or lists of ids for each element
1587 1587 """
1588 1588 if targets == 'all':
1589 1589 # allow 'all' to be evaluated on the engine
1590 1590 engine_ids = None
1591 1591 else:
1592 1592 engine_ids = self._build_targets(targets)[1]
1593 1593 content = dict(targets=engine_ids, verbose=verbose)
1594 1594 self.session.send(self._query_socket, "queue_request", content=content)
1595 1595 idents,msg = self.session.recv(self._query_socket, 0)
1596 1596 if self.debug:
1597 1597 pprint(msg)
1598 1598 content = msg['content']
1599 1599 status = content.pop('status')
1600 1600 if status != 'ok':
1601 1601 raise self._unwrap_exception(content)
1602 1602 content = rekey(content)
1603 1603 if isinstance(targets, int):
1604 1604 return content[targets]
1605 1605 else:
1606 1606 return content
1607 1607
1608 1608 @spin_first
1609 1609 def purge_results(self, jobs=[], targets=[]):
1610 1610 """Tell the Hub to forget results.
1611 1611
1612 1612 Individual results can be purged by msg_id, or the entire
1613 1613 history of specific targets can be purged.
1614 1614
1615 1615 Use `purge_results('all')` to scrub everything from the Hub's db.
1616 1616
1617 1617 Parameters
1618 1618 ----------
1619 1619
1620 1620 jobs : str or list of str or AsyncResult objects
1621 1621 the msg_ids whose results should be forgotten.
1622 1622 targets : int/str/list of ints/strs
1623 1623 The targets, by int_id, whose entire history is to be purged.
1624 1624
1625 1625 default : None
1626 1626 """
1627 1627 if not targets and not jobs:
1628 1628 raise ValueError("Must specify at least one of `targets` and `jobs`")
1629 1629 if targets:
1630 1630 targets = self._build_targets(targets)[1]
1631 1631
1632 1632 # construct msg_ids from jobs
1633 1633 if jobs == 'all':
1634 1634 msg_ids = jobs
1635 1635 else:
1636 1636 msg_ids = []
1637 1637 if isinstance(jobs, (basestring,AsyncResult)):
1638 1638 jobs = [jobs]
1639 1639 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1640 1640 if bad_ids:
1641 1641 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1642 1642 for j in jobs:
1643 1643 if isinstance(j, AsyncResult):
1644 1644 msg_ids.extend(j.msg_ids)
1645 1645 else:
1646 1646 msg_ids.append(j)
1647 1647
1648 1648 content = dict(engine_ids=targets, msg_ids=msg_ids)
1649 1649 self.session.send(self._query_socket, "purge_request", content=content)
1650 1650 idents, msg = self.session.recv(self._query_socket, 0)
1651 1651 if self.debug:
1652 1652 pprint(msg)
1653 1653 content = msg['content']
1654 1654 if content['status'] != 'ok':
1655 1655 raise self._unwrap_exception(content)
1656 1656
1657 1657 @spin_first
1658 1658 def hub_history(self):
1659 1659 """Get the Hub's history
1660 1660
1661 1661 Just like the Client, the Hub has a history, which is a list of msg_ids.
1662 1662 This will contain the history of all clients, and, depending on configuration,
1663 1663 may contain history across multiple cluster sessions.
1664 1664
1665 1665 Any msg_id returned here is a valid argument to `get_result`.
1666 1666
1667 1667 Returns
1668 1668 -------
1669 1669
1670 1670 msg_ids : list of strs
1671 1671 list of all msg_ids, ordered by task submission time.
1672 1672 """
1673 1673
1674 1674 self.session.send(self._query_socket, "history_request", content={})
1675 1675 idents, msg = self.session.recv(self._query_socket, 0)
1676 1676
1677 1677 if self.debug:
1678 1678 pprint(msg)
1679 1679 content = msg['content']
1680 1680 if content['status'] != 'ok':
1681 1681 raise self._unwrap_exception(content)
1682 1682 else:
1683 1683 return content['history']
1684 1684
1685 1685 @spin_first
1686 1686 def db_query(self, query, keys=None):
1687 1687 """Query the Hub's TaskRecord database
1688 1688
1689 1689 This will return a list of task record dicts that match `query`
1690 1690
1691 1691 Parameters
1692 1692 ----------
1693 1693
1694 1694 query : mongodb query dict
1695 1695 The search dict. See mongodb query docs for details.
1696 1696 keys : list of strs [optional]
1697 1697 The subset of keys to be returned. The default is to fetch everything but buffers.
1698 1698 'msg_id' will *always* be included.
1699 1699 """
1700 1700 if isinstance(keys, basestring):
1701 1701 keys = [keys]
1702 1702 content = dict(query=query, keys=keys)
1703 1703 self.session.send(self._query_socket, "db_request", content=content)
1704 1704 idents, msg = self.session.recv(self._query_socket, 0)
1705 1705 if self.debug:
1706 1706 pprint(msg)
1707 1707 content = msg['content']
1708 1708 if content['status'] != 'ok':
1709 1709 raise self._unwrap_exception(content)
1710 1710
1711 1711 records = content['records']
1712 1712
1713 1713 buffer_lens = content['buffer_lens']
1714 1714 result_buffer_lens = content['result_buffer_lens']
1715 1715 buffers = msg['buffers']
1716 1716 has_bufs = buffer_lens is not None
1717 1717 has_rbufs = result_buffer_lens is not None
1718 1718 for i,rec in enumerate(records):
1719 1719 # relink buffers
1720 1720 if has_bufs:
1721 1721 blen = buffer_lens[i]
1722 1722 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1723 1723 if has_rbufs:
1724 1724 blen = result_buffer_lens[i]
1725 1725 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1726 1726
1727 1727 return records
1728 1728
1729 1729 __all__ = [ 'Client' ]
General Comments 0
You need to be logged in to leave comments. Login now