Backport PR #2769: Allow shutdown when no engines are registered...
MinRK
@@ -1,1721 +1,1724 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52
53 53 from .asyncresult import AsyncResult, AsyncHubResult
54 54 from .view import DirectView, LoadBalancedView
55 55
56 56 if sys.version_info[0] >= 3:
57 57 # xrange is used in a couple 'isinstance' tests in py2
58 58 # should be just 'range' in 3k
59 59 xrange = range
60 60
61 61 #--------------------------------------------------------------------------
62 62 # Decorators for Client methods
63 63 #--------------------------------------------------------------------------
64 64
65 65 @decorator
66 66 def spin_first(f, self, *args, **kwargs):
67 67 """Call spin() to sync state prior to calling the method."""
68 68 self.spin()
69 69 return f(self, *args, **kwargs)
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Classes
74 74 #--------------------------------------------------------------------------
75 75
76 76
77 77 class ExecuteReply(object):
78 78 """wrapper for finished Execute results"""
79 79 def __init__(self, msg_id, content, metadata):
80 80 self.msg_id = msg_id
81 81 self._content = content
82 82 self.execution_count = content['execution_count']
83 83 self.metadata = metadata
84 84
85 85 def __getitem__(self, key):
86 86 return self.metadata[key]
87 87
88 88 def __getattr__(self, key):
89 89 if key not in self.metadata:
90 90 raise AttributeError(key)
91 91 return self.metadata[key]
92 92
93 93 def __repr__(self):
94 94 pyout = self.metadata['pyout'] or {'data':{}}
95 95 text_out = pyout['data'].get('text/plain', '')
96 96 if len(text_out) > 32:
97 97 text_out = text_out[:29] + '...'
98 98
99 99 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
100 100
101 101 def _repr_pretty_(self, p, cycle):
102 102 pyout = self.metadata['pyout'] or {'data':{}}
103 103 text_out = pyout['data'].get('text/plain', '')
104 104
105 105 if not text_out:
106 106 return
107 107
108 108 try:
109 109 ip = get_ipython()
110 110 except NameError:
111 111 colors = "NoColor"
112 112 else:
113 113 colors = ip.colors
114 114
115 115 if colors == "NoColor":
116 116 out = normal = ""
117 117 else:
118 118 out = TermColors.Red
119 119 normal = TermColors.Normal
120 120
121 121 if '\n' in text_out and not text_out.startswith('\n'):
122 122 # add newline for multiline reprs
123 123 text_out = '\n' + text_out
124 124
125 125 p.text(
126 126 out + u'Out[%i:%i]: ' % (
127 127 self.metadata['engine_id'], self.execution_count
128 128 ) + normal + text_out
129 129 )
130 130
131 131 def _repr_html_(self):
132 132 pyout = self.metadata['pyout'] or {'data':{}}
133 133 return pyout['data'].get("text/html")
134 134
135 135 def _repr_latex_(self):
136 136 pyout = self.metadata['pyout'] or {'data':{}}
137 137 return pyout['data'].get("text/latex")
138 138
139 139 def _repr_json_(self):
140 140 pyout = self.metadata['pyout'] or {'data':{}}
141 141 return pyout['data'].get("application/json")
142 142
143 143 def _repr_javascript_(self):
144 144 pyout = self.metadata['pyout'] or {'data':{}}
145 145 return pyout['data'].get("application/javascript")
146 146
147 147 def _repr_png_(self):
148 148 pyout = self.metadata['pyout'] or {'data':{}}
149 149 return pyout['data'].get("image/png")
150 150
151 151 def _repr_jpeg_(self):
152 152 pyout = self.metadata['pyout'] or {'data':{}}
153 153 return pyout['data'].get("image/jpeg")
154 154
155 155 def _repr_svg_(self):
156 156 pyout = self.metadata['pyout'] or {'data':{}}
157 157 return pyout['data'].get("image/svg+xml")
158 158
159 159
160 160 class Metadata(dict):
161 161 """Subclass of dict for initializing metadata values.
162 162
163 163 Attribute access works on keys.
164 164
165 165 These objects have a strict set of keys - an error is raised if you try
166 166 to add new keys.
167 167 """
168 168 def __init__(self, *args, **kwargs):
169 169 dict.__init__(self)
170 170 md = {'msg_id' : None,
171 171 'submitted' : None,
172 172 'started' : None,
173 173 'completed' : None,
174 174 'received' : None,
175 175 'engine_uuid' : None,
176 176 'engine_id' : None,
177 177 'follow' : None,
178 178 'after' : None,
179 179 'status' : None,
180 180
181 181 'pyin' : None,
182 182 'pyout' : None,
183 183 'pyerr' : None,
184 184 'stdout' : '',
185 185 'stderr' : '',
186 186 'outputs' : [],
187 187 'outputs_ready' : False,
188 188 }
189 189 self.update(md)
190 190 self.update(dict(*args, **kwargs))
191 191
192 192 def __getattr__(self, key):
193 193 """getattr aliased to getitem"""
194 194 if key in self.iterkeys():
195 195 return self[key]
196 196 else:
197 197 raise AttributeError(key)
198 198
199 199 def __setattr__(self, key, value):
200 200 """setattr aliased to setitem, with strict"""
201 201 if key in self.iterkeys():
202 202 self[key] = value
203 203 else:
204 204 raise AttributeError(key)
205 205
206 206 def __setitem__(self, key, value):
207 207 """strict static key enforcement"""
208 208 if key in self.iterkeys():
209 209 dict.__setitem__(self, key, value)
210 210 else:
211 211 raise KeyError(key)
212 212
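# A minimal sketch of Metadata's strict-key behavior (values illustrative only):
#
#     md = Metadata(status='ok')
#     md.status            # attribute access works for known keys -> 'ok'
#     md.engine_id = 0     # setattr is likewise restricted to known keys
#     md['bogus'] = 1      # raises KeyError: the key set is fixed at init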
213 213
214 214 class Client(HasTraits):
215 215 """A semi-synchronous client to the IPython ZMQ cluster
216 216
217 217 Parameters
218 218 ----------
219 219
220 220 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
221 221 Connection information for the Hub's registration. If a json connector
222 222 file is given, then likely no further configuration is necessary.
223 223 [Default: use profile]
224 224 profile : bytes
225 225 The name of the Cluster profile to be used to find connector information.
226 226 If run from an IPython application, the default profile will be the same
227 227 as the running application, otherwise it will be 'default'.
228 228 context : zmq.Context
229 229 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 230 debug : bool
231 231 flag for lots of message printing for debug purposes
232 232 timeout : int/float
233 233 time (in seconds) to wait for connection replies from the Hub
234 234 [Default: 10]
235 235
236 236 #-------------- session related args ----------------
237 237
238 238 config : Config object
239 239 If specified, this will be relayed to the Session for configuration
240 240 username : str
241 241 set username for the session object
242 242 packer : str (import_string) or callable
243 243 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
244 244 function to serialize messages. Must support same input as
245 245 JSON, and output must be bytes.
246 246 You can pass a callable directly as `pack`
247 247 unpacker : str (import_string) or callable
248 248 The inverse of packer. Only necessary if packer is specified as *not* one
249 249 of 'json' or 'pickle'.
250 250
251 251 #-------------- ssh related args ----------------
252 252 # These are args for configuring the ssh tunnel to be used
253 253 # credentials are used to forward connections over ssh to the Controller
254 254 # Note that the ip given in `addr` needs to be relative to sshserver
255 255 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
256 256 # and set sshserver as the same machine the Controller is on. However,
257 257 # the only requirement is that sshserver is able to see the Controller
258 258 # (i.e. is within the same trusted network).
259 259
260 260 sshserver : str
261 261 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
262 262 If keyfile or password is specified, and this is not, it will default to
263 263 the ip given in addr.
264 264 sshkey : str; path to ssh private key file
265 265 This specifies a key to be used in ssh login, default None.
266 266 Regular default ssh keys will be used without specifying this argument.
267 267 password : str
268 268 Your ssh password to sshserver. Note that if this is left None,
269 269 you will be prompted for it if passwordless key based login is unavailable.
270 270 paramiko : bool
271 271 flag for whether to use paramiko instead of shell ssh for tunneling.
272 272 [default: True on win32, False else]
273 273
274 274 #-------------- exec authentication args ----------------
275 275 If even localhost is untrusted, you can have some protection against
276 276 unauthorized execution by signing messages with HMAC digests.
277 277 Messages are still sent as cleartext, so if someone can snoop your
278 278 loopback traffic this will not protect your privacy, but will prevent
279 279 unauthorized execution.
280 280
281 281 exec_key : str
282 282 an authentication key or file containing a key
283 283 default: None
284 284
285 285
286 286 Attributes
287 287 ----------
288 288
289 289 ids : list of int engine IDs
290 290 requesting the ids attribute always synchronizes
291 291 the registration state. To request ids without synchronization,
292 292 use the semi-private _ids attribute.
293 293
294 294 history : list of msg_ids
295 295 a list of msg_ids, keeping track of all the execution
296 296 messages you have submitted in order.
297 297
298 298 outstanding : set of msg_ids
299 299 a set of msg_ids that have been submitted, but whose
300 300 results have not yet been received.
301 301
302 302 results : dict
303 303 a dict of all our results, keyed by msg_id
304 304
305 305 block : bool
306 306 determines default behavior when block not specified
307 307 in execution methods
308 308
309 309 Methods
310 310 -------
311 311
312 312 spin
313 313 flushes incoming results and registration state changes
314 314 control methods also spin, and requesting `ids` keeps state up to date
315 315
316 316 wait
317 317 wait on one or more msg_ids
318 318
319 319 execution methods
320 320 apply
321 321 legacy: execute, run
322 322
323 323 data movement
324 324 push, pull, scatter, gather
325 325
326 326 query methods
327 327 queue_status, get_result, purge, result_status
328 328
329 329 control methods
330 330 abort, shutdown
331 331
332 332 """
333 333
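    # A minimal connection sketch (profile name, path, and host are illustrative):
    #
    #     rc = Client()                         # default profile's connection file
    #     rc = Client(profile='mycluster')      # a hypothetical named profile
    #     rc = Client('/path/to/ipcontroller-client.json',
    #                 sshserver='user@gateway') # tunnel the connections over ssh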
334 334
335 335 block = Bool(False)
336 336 outstanding = Set()
337 337 results = Instance('collections.defaultdict', (dict,))
338 338 metadata = Instance('collections.defaultdict', (Metadata,))
339 339 history = List()
340 340 debug = Bool(False)
341 341 _spin_thread = Any()
342 342 _stop_spinning = Any()
343 343
344 344 profile=Unicode()
345 345 def _profile_default(self):
346 346 if BaseIPythonApplication.initialized():
347 347 # an IPython app *might* be running, try to get its profile
348 348 try:
349 349 return BaseIPythonApplication.instance().profile
350 350 except (AttributeError, MultipleInstanceError):
351 351 # could be a *different* subclass of config.Application,
352 352 # which would raise one of these two errors.
353 353 return u'default'
354 354 else:
355 355 return u'default'
356 356
357 357
358 358 _outstanding_dict = Instance('collections.defaultdict', (set,))
359 359 _ids = List()
360 360 _connected=Bool(False)
361 361 _ssh=Bool(False)
362 362 _context = Instance('zmq.Context')
363 363 _config = Dict()
364 364 _engines=Instance(util.ReverseDict, (), {})
365 365 # _hub_socket=Instance('zmq.Socket')
366 366 _query_socket=Instance('zmq.Socket')
367 367 _control_socket=Instance('zmq.Socket')
368 368 _iopub_socket=Instance('zmq.Socket')
369 369 _notification_socket=Instance('zmq.Socket')
370 370 _mux_socket=Instance('zmq.Socket')
371 371 _task_socket=Instance('zmq.Socket')
372 372 _task_scheme=Unicode()
373 373 _closed = False
374 374 _ignored_control_replies=Integer(0)
375 375 _ignored_hub_replies=Integer(0)
376 376
377 377 def __new__(self, *args, **kw):
378 378 # don't raise on positional args
379 379 return HasTraits.__new__(self, **kw)
380 380
381 381 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
382 382 context=None, debug=False, exec_key=None,
383 383 sshserver=None, sshkey=None, password=None, paramiko=None,
384 384 timeout=10, **extra_args
385 385 ):
386 386 if profile:
387 387 super(Client, self).__init__(debug=debug, profile=profile)
388 388 else:
389 389 super(Client, self).__init__(debug=debug)
390 390 if context is None:
391 391 context = zmq.Context.instance()
392 392 self._context = context
393 393 self._stop_spinning = Event()
394 394
395 395 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
396 396 if self._cd is not None:
397 397 if url_or_file is None:
398 398 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
399 399 if url_or_file is None:
400 400 raise ValueError(
401 401 "I can't find enough information to connect to a hub!"
402 402 " Please specify at least one of url_or_file or profile."
403 403 )
404 404
405 405 if not util.is_url(url_or_file):
406 406 # it's not a url, try for a file
407 407 if not os.path.exists(url_or_file):
408 408 if self._cd:
409 409 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
410 410 if not os.path.exists(url_or_file):
411 411 raise IOError("Connection file not found: %r" % url_or_file)
412 412 with open(url_or_file) as f:
413 413 cfg = json.loads(f.read())
414 414 else:
415 415 cfg = {'url':url_or_file}
416 416
417 417 # sync defaults from args, json:
418 418 if sshserver:
419 419 cfg['ssh'] = sshserver
420 420 if exec_key:
421 421 cfg['exec_key'] = exec_key
422 422 exec_key = cfg['exec_key']
423 423 location = cfg.setdefault('location', None)
424 424 cfg['url'] = util.disambiguate_url(cfg['url'], location)
425 425 url = cfg['url']
426 426 proto,addr,port = util.split_url(url)
427 427 if location is not None and addr == '127.0.0.1':
428 428 # location specified, and connection is expected to be local
429 429 if location not in LOCAL_IPS and not sshserver:
430 430 # load ssh from JSON *only* if the controller is not on
431 431 # this machine
432 432 sshserver=cfg['ssh']
433 433 if location not in LOCAL_IPS and not sshserver:
434 434 # warn if no ssh specified, but SSH is probably needed
435 435 # This is only a warning, because the most likely cause
436 436 # is a local Controller on a laptop whose IP is dynamic
437 437 warnings.warn("""
438 438 Controller appears to be listening on localhost, but not on this machine.
439 439 If this is true, you should specify Client(...,sshserver='you@%s')
440 440 or instruct your controller to listen on an external IP."""%location,
441 441 RuntimeWarning)
442 442 elif not sshserver:
443 443 # otherwise sync with cfg
444 444 sshserver = cfg['ssh']
445 445
446 446 self._config = cfg
447 447
448 448 self._ssh = bool(sshserver or sshkey or password)
449 449 if self._ssh and sshserver is None:
450 450 # default to ssh via localhost
451 451 sshserver = url.split('://')[1].split(':')[0]
452 452 if self._ssh and password is None:
453 453 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
454 454 password=False
455 455 else:
456 456 password = getpass("SSH Password for %s: "%sshserver)
457 457 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
458 458
459 459 # configure and construct the session
460 460 if exec_key is not None:
461 461 if os.path.isfile(exec_key):
462 462 extra_args['keyfile'] = exec_key
463 463 else:
464 464 exec_key = cast_bytes(exec_key)
465 465 extra_args['key'] = exec_key
466 466 self.session = Session(**extra_args)
467 467
468 468 self._query_socket = self._context.socket(zmq.DEALER)
469 469 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
470 470 if self._ssh:
471 471 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
472 472 else:
473 473 self._query_socket.connect(url)
474 474
475 475 self.session.debug = self.debug
476 476
477 477 self._notification_handlers = {'registration_notification' : self._register_engine,
478 478 'unregistration_notification' : self._unregister_engine,
479 479 'shutdown_notification' : lambda msg: self.close(),
480 480 }
481 481 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
482 482 'apply_reply' : self._handle_apply_reply}
483 483 self._connect(sshserver, ssh_kwargs, timeout)
484 484
485 485 # last step: setup magics, if we are in IPython:
486 486
487 487 try:
488 488 ip = get_ipython()
489 489 except NameError:
490 490 return
491 491 else:
492 492 if 'px' not in ip.magics_manager.magics:
493 493 # in IPython but we are the first Client.
494 494 # activate a default view for parallel magics.
495 495 self.activate()
496 496
497 497 def __del__(self):
498 498 """cleanup sockets, but _not_ context."""
499 499 self.close()
500 500
501 501 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
502 502 if ipython_dir is None:
503 503 ipython_dir = get_ipython_dir()
504 504 if profile_dir is not None:
505 505 try:
506 506 self._cd = ProfileDir.find_profile_dir(profile_dir)
507 507 return
508 508 except ProfileDirError:
509 509 pass
510 510 elif profile is not None:
511 511 try:
512 512 self._cd = ProfileDir.find_profile_dir_by_name(
513 513 ipython_dir, profile)
514 514 return
515 515 except ProfileDirError:
516 516 pass
517 517 self._cd = None
518 518
519 519 def _update_engines(self, engines):
520 520 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
521 521 for k,v in engines.iteritems():
522 522 eid = int(k)
523 523 self._engines[eid] = v
524 524 self._ids.append(eid)
525 525 self._ids = sorted(self._ids)
526 526 if sorted(self._engines.keys()) != range(len(self._engines)) and \
527 527 self._task_scheme == 'pure' and self._task_socket:
528 528 self._stop_scheduling_tasks()
529 529
530 530 def _stop_scheduling_tasks(self):
531 531 """Stop scheduling tasks because an engine has been unregistered
532 532 from a pure ZMQ scheduler.
533 533 """
534 534 self._task_socket.close()
535 535 self._task_socket = None
536 536 msg = "An engine has been unregistered, and we are using pure " +\
537 537 "ZMQ task scheduling. Task farming will be disabled."
538 538 if self.outstanding:
539 539 msg += " If you were running tasks when this happened, " +\
540 540 "some `outstanding` msg_ids may never resolve."
541 541 warnings.warn(msg, RuntimeWarning)
542 542
543 543 def _build_targets(self, targets):
544 544 """Turn valid target IDs or 'all' into two lists:
545 545 (uuids, int_ids).
546 546 """
547 547 if not self._ids:
548 548 # flush notification socket if no engines yet, just in case
549 549 if not self.ids:
550 550 raise error.NoEnginesRegistered("Can't build targets without any engines")
551 551
552 552 if targets is None:
553 553 targets = self._ids
554 554 elif isinstance(targets, basestring):
555 555 if targets.lower() == 'all':
556 556 targets = self._ids
557 557 else:
558 558 raise TypeError("%r not valid str target, must be 'all'"%(targets))
559 559 elif isinstance(targets, int):
560 560 if targets < 0:
561 561 targets = self.ids[targets]
562 562 if targets not in self._ids:
563 563 raise IndexError("No such engine: %i"%targets)
564 564 targets = [targets]
565 565
566 566 if isinstance(targets, slice):
567 567 indices = range(len(self._ids))[targets]
568 568 ids = self.ids
569 569 targets = [ ids[i] for i in indices ]
570 570
571 571 if not isinstance(targets, (tuple, list, xrange)):
572 572 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
573 573
574 574 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
575 575
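    # A sketch of the target specs _build_targets accepts (ids illustrative):
    #
    #     rc._build_targets('all')        # every registered engine
    #     rc._build_targets(0)            # a single engine id (negatives index from the end)
    #     rc._build_targets([0, 2])       # an explicit list of ids
    #     rc._build_targets(slice(0, 4))  # a slice of the current id list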
576 576 def _connect(self, sshserver, ssh_kwargs, timeout):
577 577 """setup all our socket connections to the cluster. This is called from
578 578 __init__."""
579 579
580 580 # Maybe allow reconnecting?
581 581 if self._connected:
582 582 return
583 583 self._connected=True
584 584
585 585 def connect_socket(s, url):
586 586 url = util.disambiguate_url(url, self._config['location'])
587 587 if self._ssh:
588 588 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
589 589 else:
590 590 return s.connect(url)
591 591
592 592 self.session.send(self._query_socket, 'connection_request')
593 593 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
594 594 poller = zmq.Poller()
595 595 poller.register(self._query_socket, zmq.POLLIN)
596 596 # poll expects milliseconds, timeout is seconds
597 597 evts = poller.poll(timeout*1000)
598 598 if not evts:
599 599 raise error.TimeoutError("Hub connection request timed out")
600 600 idents,msg = self.session.recv(self._query_socket,mode=0)
601 601 if self.debug:
602 602 pprint(msg)
603 603 msg = Message(msg)
604 604 content = msg.content
605 605 self._config['registration'] = dict(content)
606 606 if content.status == 'ok':
607 607 ident = self.session.bsession
608 608 if content.mux:
609 609 self._mux_socket = self._context.socket(zmq.DEALER)
610 610 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
611 611 connect_socket(self._mux_socket, content.mux)
612 612 if content.task:
613 613 self._task_scheme, task_addr = content.task
614 614 self._task_socket = self._context.socket(zmq.DEALER)
615 615 self._task_socket.setsockopt(zmq.IDENTITY, ident)
616 616 connect_socket(self._task_socket, task_addr)
617 617 if content.notification:
618 618 self._notification_socket = self._context.socket(zmq.SUB)
619 619 connect_socket(self._notification_socket, content.notification)
620 620 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
621 621 # if content.query:
622 622 # self._query_socket = self._context.socket(zmq.DEALER)
623 623 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
624 624 # connect_socket(self._query_socket, content.query)
625 625 if content.control:
626 626 self._control_socket = self._context.socket(zmq.DEALER)
627 627 self._control_socket.setsockopt(zmq.IDENTITY, ident)
628 628 connect_socket(self._control_socket, content.control)
629 629 if content.iopub:
630 630 self._iopub_socket = self._context.socket(zmq.SUB)
631 631 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
632 632 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
633 633 connect_socket(self._iopub_socket, content.iopub)
634 634 self._update_engines(dict(content.engines))
635 635 else:
636 636 self._connected = False
637 637 raise Exception("Failed to connect!")
638 638
639 639 #--------------------------------------------------------------------------
640 640 # handlers and callbacks for incoming messages
641 641 #--------------------------------------------------------------------------
642 642
643 643 def _unwrap_exception(self, content):
644 644 """unwrap exception, and remap engine_id to int."""
645 645 e = error.unwrap_exception(content)
646 646 # print e.traceback
647 647 if e.engine_info:
648 648 e_uuid = e.engine_info['engine_uuid']
649 649 eid = self._engines[e_uuid]
650 650 e.engine_info['engine_id'] = eid
651 651 return e
652 652
653 653 def _extract_metadata(self, header, parent, content):
654 654 md = {'msg_id' : parent['msg_id'],
655 655 'received' : datetime.now(),
656 656 'engine_uuid' : header.get('engine', None),
657 657 'follow' : parent.get('follow', []),
658 658 'after' : parent.get('after', []),
659 659 'status' : content['status'],
660 660 }
661 661
662 662 if md['engine_uuid'] is not None:
663 663 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
664 664
665 665 if 'date' in parent:
666 666 md['submitted'] = parent['date']
667 667 if 'started' in header:
668 668 md['started'] = header['started']
669 669 if 'date' in header:
670 670 md['completed'] = header['date']
671 671 return md
672 672
673 673 def _register_engine(self, msg):
674 674 """Register a new engine, and update our connection info."""
675 675 content = msg['content']
676 676 eid = content['id']
677 677 d = {eid : content['queue']}
678 678 self._update_engines(d)
679 679
680 680 def _unregister_engine(self, msg):
681 681 """Unregister an engine that has died."""
682 682 content = msg['content']
683 683 eid = int(content['id'])
684 684 if eid in self._ids:
685 685 self._ids.remove(eid)
686 686 uuid = self._engines.pop(eid)
687 687
688 688 self._handle_stranded_msgs(eid, uuid)
689 689
690 690 if self._task_socket and self._task_scheme == 'pure':
691 691 self._stop_scheduling_tasks()
692 692
693 693 def _handle_stranded_msgs(self, eid, uuid):
694 694 """Handle messages known to be on an engine when the engine unregisters.
695 695
696 696 It is possible that this will fire prematurely - that is, an engine will
697 697 go down after completing a result, and the client will be notified
698 698 of the unregistration and later receive the successful result.
699 699 """
700 700
701 701 outstanding = self._outstanding_dict[uuid]
702 702
703 703 for msg_id in list(outstanding):
704 704 if msg_id in self.results:
705 705 # we already have the result; nothing more to do
706 706 continue
707 707 try:
708 708 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
709 709 except:
710 710 content = error.wrap_exception()
711 711 # build a fake message:
712 712 parent = {}
713 713 header = {}
714 714 parent['msg_id'] = msg_id
715 715 header['engine'] = uuid
716 716 header['date'] = datetime.now()
717 717 msg = dict(parent_header=parent, header=header, content=content)
718 718 self._handle_apply_reply(msg)
719 719
720 720 def _handle_execute_reply(self, msg):
721 721 """Save the reply to an execute_request into our results.
722 722
723 723 execute requests are sent via send_execute_request (e.g. by the parallel magics).
724 724 """
725 725
726 726 parent = msg['parent_header']
727 727 msg_id = parent['msg_id']
728 728 if msg_id not in self.outstanding:
729 729 if msg_id in self.history:
730 730 print ("got stale result: %s"%msg_id)
731 731 else:
732 732 print ("got unknown result: %s"%msg_id)
733 733 else:
734 734 self.outstanding.remove(msg_id)
735 735
736 736 content = msg['content']
737 737 header = msg['header']
738 738
739 739 # construct metadata:
740 740 md = self.metadata[msg_id]
741 741 md.update(self._extract_metadata(header, parent, content))
742 742 # is this redundant?
743 743 self.metadata[msg_id] = md
744 744
745 745 e_outstanding = self._outstanding_dict[md['engine_uuid']]
746 746 if msg_id in e_outstanding:
747 747 e_outstanding.remove(msg_id)
748 748
749 749 # construct result:
750 750 if content['status'] == 'ok':
751 751 self.results[msg_id] = ExecuteReply(msg_id, content, md)
752 752 elif content['status'] == 'aborted':
753 753 self.results[msg_id] = error.TaskAborted(msg_id)
754 754 elif content['status'] == 'resubmitted':
755 755 # TODO: handle resubmission
756 756 pass
757 757 else:
758 758 self.results[msg_id] = self._unwrap_exception(content)
759 759
760 760 def _handle_apply_reply(self, msg):
761 761 """Save the reply to an apply_request into our results."""
762 762 parent = msg['parent_header']
763 763 msg_id = parent['msg_id']
764 764 if msg_id not in self.outstanding:
765 765 if msg_id in self.history:
766 766 print ("got stale result: %s"%msg_id)
767 767 print self.results[msg_id]
768 768 print msg
769 769 else:
770 770 print ("got unknown result: %s"%msg_id)
771 771 else:
772 772 self.outstanding.remove(msg_id)
773 773 content = msg['content']
774 774 header = msg['header']
775 775
776 776 # construct metadata:
777 777 md = self.metadata[msg_id]
778 778 md.update(self._extract_metadata(header, parent, content))
779 779 # is this redundant?
780 780 self.metadata[msg_id] = md
781 781
782 782 e_outstanding = self._outstanding_dict[md['engine_uuid']]
783 783 if msg_id in e_outstanding:
784 784 e_outstanding.remove(msg_id)
785 785
786 786 # construct result:
787 787 if content['status'] == 'ok':
788 788 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
789 789 elif content['status'] == 'aborted':
790 790 self.results[msg_id] = error.TaskAborted(msg_id)
791 791 elif content['status'] == 'resubmitted':
792 792 # TODO: handle resubmission
793 793 pass
794 794 else:
795 795 self.results[msg_id] = self._unwrap_exception(content)
796 796
797 797 def _flush_notifications(self):
798 798 """Flush notifications of engine registrations waiting
799 799 in ZMQ queue."""
800 800 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
801 801 while msg is not None:
802 802 if self.debug:
803 803 pprint(msg)
804 804 msg_type = msg['header']['msg_type']
805 805 handler = self._notification_handlers.get(msg_type, None)
806 806 if handler is None:
807 807 raise Exception("Unhandled message type: %s"%msg.msg_type)
808 808 else:
809 809 handler(msg)
810 810 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
811 811
812 812 def _flush_results(self, sock):
813 813 """Flush task or queue results waiting in ZMQ queue."""
814 814 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
815 815 while msg is not None:
816 816 if self.debug:
817 817 pprint(msg)
818 818 msg_type = msg['header']['msg_type']
819 819 handler = self._queue_handlers.get(msg_type, None)
820 820 if handler is None:
821 821 raise Exception("Unhandled message type: %s"%msg.msg_type)
822 822 else:
823 823 handler(msg)
824 824 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
825 825
826 826 def _flush_control(self, sock):
827 827 """Flush replies from the control channel waiting
828 828 in the ZMQ queue.
829 829
830 830 Currently: ignore them."""
831 831 if self._ignored_control_replies <= 0:
832 832 return
833 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 834 while msg is not None:
835 835 self._ignored_control_replies -= 1
836 836 if self.debug:
837 837 pprint(msg)
838 838 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
839 839
840 840 def _flush_ignored_control(self):
841 841 """flush ignored control replies"""
842 842 while self._ignored_control_replies > 0:
843 843 self.session.recv(self._control_socket)
844 844 self._ignored_control_replies -= 1
845 845
846 846 def _flush_ignored_hub_replies(self):
847 847 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
848 848 while msg is not None:
849 849 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
850 850
851 851 def _flush_iopub(self, sock):
852 852 """Flush replies from the iopub channel waiting
853 853 in the ZMQ queue.
854 854 """
855 855 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
856 856 while msg is not None:
857 857 if self.debug:
858 858 pprint(msg)
859 859 parent = msg['parent_header']
860 860 # ignore IOPub messages with no parent.
861 861 # Caused by print statements or warnings from before the first execution.
862 862 if not parent:
863 863 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK); continue # fetch the next message first, else this loop never advances
864 864 msg_id = parent['msg_id']
865 865 content = msg['content']
866 866 header = msg['header']
867 867 msg_type = msg['header']['msg_type']
868 868
869 869 # init metadata:
870 870 md = self.metadata[msg_id]
871 871
872 872 if msg_type == 'stream':
873 873 name = content['name']
874 874 s = md[name] or ''
875 875 md[name] = s + content['data']
876 876 elif msg_type == 'pyerr':
877 877 md.update({'pyerr' : self._unwrap_exception(content)})
878 878 elif msg_type == 'pyin':
879 879 md.update({'pyin' : content['code']})
880 880 elif msg_type == 'display_data':
881 881 md['outputs'].append(content)
882 882 elif msg_type == 'pyout':
883 883 md['pyout'] = content
884 884 elif msg_type == 'status':
885 885 # idle message comes after all outputs
886 886 if content['execution_state'] == 'idle':
887 887 md['outputs_ready'] = True
888 888 else:
889 889 # unhandled msg_type (status, etc.)
890 890 pass
891 891
892 892 # redundant?
893 893 self.metadata[msg_id] = md
894 894
895 895 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
896 896
897 897 #--------------------------------------------------------------------------
898 898 # len, getitem
899 899 #--------------------------------------------------------------------------
900 900
901 901 def __len__(self):
902 902 """len(client) returns # of engines."""
903 903 return len(self.ids)
904 904
905 905 def __getitem__(self, key):
906 906 """index access returns DirectView multiplexer objects
907 907
908 908 Must be int, slice, or list/tuple/xrange of ints"""
909 909 if not isinstance(key, (int, slice, tuple, list, xrange)):
910 910 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
911 911 else:
912 912 return self.direct_view(key)
913 913
914 914 #--------------------------------------------------------------------------
915 915 # Begin public methods
916 916 #--------------------------------------------------------------------------
917 917
918 918 @property
919 919 def ids(self):
920 920 """Always up-to-date ids property."""
921 921 self._flush_notifications()
922 922 # always copy:
923 923 return list(self._ids)
924 924
925 925 def activate(self, targets='all', suffix=''):
926 926 """Create a DirectView and register it with IPython magics
927 927
928 928 Defines the magics `%px, %autopx, %pxresult, %%px`
929 929
930 930 Parameters
931 931 ----------
932 932
933 933 targets: int, list of ints, or 'all'
934 934 The engines on which the view's magics will run
935 935 suffix: str [default: '']
936 936 The suffix, if any, for the magics. This allows you to have
937 937 multiple views associated with parallel magics at the same time.
938 938
939 939 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
940 940 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
941 941 on engine 0.
942 942 """
943 943 view = self.direct_view(targets)
944 944 view.block = True
945 945 view.activate(suffix)
946 946 return view
947 947
948 948 def close(self):
949 949 if self._closed:
950 950 return
951 951 self.stop_spin_thread()
952 952 snames = filter(lambda n: n.endswith('socket'), dir(self))
953 953 for socket in map(lambda name: getattr(self, name), snames):
954 954 if isinstance(socket, zmq.Socket) and not socket.closed:
955 955 socket.close()
956 956 self._closed = True
957 957
958 958 def _spin_every(self, interval=1):
959 959 """target func for use in spin_thread"""
960 960 while True:
961 961 if self._stop_spinning.is_set():
962 962 return
963 963 time.sleep(interval)
964 964 self.spin()
965 965
966 966 def spin_thread(self, interval=1):
967 967 """call Client.spin() in a background thread on some regular interval
968 968
969 969 This helps ensure that messages don't pile up too much in the zmq queue
970 970 while you are working on other things, or just leaving an idle terminal.
971 971
972 972 It also helps limit potential padding of the `received` timestamp
973 973 on AsyncResult objects, used for timings.
974 974
975 975 Parameters
976 976 ----------
977 977
978 978 interval : float, optional
979 979 The interval on which to spin the client in the background thread
980 980 (simply passed to time.sleep).
981 981
982 982 Notes
983 983 -----
984 984
985 985 For precision timing, you may want to use this method to put a bound
986 986 on the jitter (in seconds) in `received` timestamps used
987 987 in AsyncResult.wall_time.
988 988
989 989 """
990 990 if self._spin_thread is not None:
991 991 self.stop_spin_thread()
992 992 self._stop_spinning.clear()
993 993 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
994 994 self._spin_thread.daemon = True
995 995 self._spin_thread.start()
996 996
997 997 def stop_spin_thread(self):
998 998 """stop background spin_thread, if any"""
999 999 if self._spin_thread is not None:
1000 1000 self._stop_spinning.set()
1001 1001 self._spin_thread.join()
1002 1002 self._spin_thread = None
1003 1003
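    # A usage sketch for background spinning (the interval is illustrative):
    #
    #     rc.spin_thread(interval=0.5)  # flush incoming queues twice a second
    #     ...                           # do other work; 'received' stamps stay fresh
    #     rc.stop_spin_thread()         # sets the Event and joins the thread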
1004 1004 def spin(self):
1005 1005 """Flush any registration notifications and execution results
1006 1006 waiting in the ZMQ queue.
1007 1007 """
1008 1008 if self._notification_socket:
1009 1009 self._flush_notifications()
1010 1010 if self._iopub_socket:
1011 1011 self._flush_iopub(self._iopub_socket)
1012 1012 if self._mux_socket:
1013 1013 self._flush_results(self._mux_socket)
1014 1014 if self._task_socket:
1015 1015 self._flush_results(self._task_socket)
1016 1016 if self._control_socket:
1017 1017 self._flush_control(self._control_socket)
1018 1018 if self._query_socket:
1019 1019 self._flush_ignored_hub_replies()
1020 1020
1021 1021 def wait(self, jobs=None, timeout=-1):
1022 1022 """waits on one or more `jobs`, for up to `timeout` seconds.
1023 1023
1024 1024 Parameters
1025 1025 ----------
1026 1026
1027 1027 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1028 1028 ints are indices to self.history
1029 1029 strs are msg_ids
1030 1030 default: wait on all outstanding messages
1031 1031 timeout : float
1032 1032 a time in seconds, after which to give up.
1033 1033 default is -1, which means no timeout
1034 1034
1035 1035 Returns
1036 1036 -------
1037 1037
1038 1038 True : when all msg_ids are done
1039 1039 False : timeout reached, some msg_ids still outstanding
1040 1040 """
1041 1041 tic = time.time()
1042 1042 if jobs is None:
1043 1043 theids = self.outstanding
1044 1044 else:
1045 1045 if isinstance(jobs, (int, basestring, AsyncResult)):
1046 1046 jobs = [jobs]
1047 1047 theids = set()
1048 1048 for job in jobs:
1049 1049 if isinstance(job, int):
1050 1050 # index access
1051 1051 job = self.history[job]
1052 1052 elif isinstance(job, AsyncResult):
1053 1053 map(theids.add, job.msg_ids)
1054 1054 continue
1055 1055 theids.add(job)
1056 1056 if not theids.intersection(self.outstanding):
1057 1057 return True
1058 1058 self.spin()
1059 1059 while theids.intersection(self.outstanding):
1060 1060 if timeout >= 0 and ( time.time()-tic ) > timeout:
1061 1061 break
1062 1062 time.sleep(1e-3)
1063 1063 self.spin()
1064 1064 return len(theids.intersection(self.outstanding)) == 0
1065 1065
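    # A usage sketch for wait (names are illustrative):
    #
    #     ar = view.apply_async(some_function)    # a hypothetical submission
    #     if not rc.wait([ar], timeout=5):
    #         print "still outstanding after 5s"  # False means the timeout hit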
1066 1066 #--------------------------------------------------------------------------
1067 1067 # Control methods
1068 1068 #--------------------------------------------------------------------------
1069 1069
1070 1070 @spin_first
1071 1071 def clear(self, targets=None, block=None):
1072 1072 """Clear the namespace in target(s)."""
1073 1073 block = self.block if block is None else block
1074 1074 targets = self._build_targets(targets)[0]
1075 1075 for t in targets:
1076 1076 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1077 1077 error = False
1078 1078 if block:
1079 1079 self._flush_ignored_control()
1080 1080 for i in range(len(targets)):
1081 1081 idents,msg = self.session.recv(self._control_socket,0)
1082 1082 if self.debug:
1083 1083 pprint(msg)
1084 1084 if msg['content']['status'] != 'ok':
1085 1085 error = self._unwrap_exception(msg['content'])
1086 1086 else:
1087 1087 self._ignored_control_replies += len(targets)
1088 1088 if error:
1089 1089 raise error
1090 1090
1091 1091
1092 1092 @spin_first
1093 1093 def abort(self, jobs=None, targets=None, block=None):
1094 1094 """Abort specific jobs from the execution queues of target(s).
1095 1095
1096 1096 This is a mechanism to prevent jobs that have already been submitted
1097 1097 from executing.
1098 1098
1099 1099 Parameters
1100 1100 ----------
1101 1101
1102 1102 jobs : msg_id, list of msg_ids, or AsyncResult
1103 1103 The jobs to be aborted
1104 1104
1105 1105 If unspecified/None: abort all outstanding jobs.
1106 1106
1107 1107 """
1108 1108 block = self.block if block is None else block
1109 1109 jobs = jobs if jobs is not None else list(self.outstanding)
1110 1110 targets = self._build_targets(targets)[0]
1111 1111
1112 1112 msg_ids = []
1113 1113 if isinstance(jobs, (basestring,AsyncResult)):
1114 1114 jobs = [jobs]
1115 1115 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1116 1116 if bad_ids:
1117 1117 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1118 1118 for j in jobs:
1119 1119 if isinstance(j, AsyncResult):
1120 1120 msg_ids.extend(j.msg_ids)
1121 1121 else:
1122 1122 msg_ids.append(j)
1123 1123 content = dict(msg_ids=msg_ids)
1124 1124 for t in targets:
1125 1125 self.session.send(self._control_socket, 'abort_request',
1126 1126 content=content, ident=t)
1127 1127 error = False
1128 1128 if block:
1129 1129 self._flush_ignored_control()
1130 1130 for i in range(len(targets)):
1131 1131 idents,msg = self.session.recv(self._control_socket,0)
1132 1132 if self.debug:
1133 1133 pprint(msg)
1134 1134 if msg['content']['status'] != 'ok':
1135 1135 error = self._unwrap_exception(msg['content'])
1136 1136 else:
1137 1137 self._ignored_control_replies += len(targets)
1138 1138 if error:
1139 1139 raise error
1140 1140
1141 1141 @spin_first
1142 1142 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1143 1143 """Terminates one or more engine processes, optionally including the hub.
1144 1144
1145 1145 Parameters
1146 1146 ----------
1147 1147
1148 1148 targets: list of ints or 'all' [default: all]
1149 1149 Which engines to shutdown.
1150 1150 hub: bool [default: False]
1151 1151 Whether to include the Hub. hub=True implies targets='all'.
1152 1152 block: bool [default: self.block]
1153 1153 Whether to wait for clean shutdown replies or not.
1154 1154 restart: bool [default: False]
1155 1155 NOT IMPLEMENTED
1156 1156 whether to restart engines after shutting them down.
1157 1157 """
1158
1158 from IPython.parallel.error import NoEnginesRegistered
1159 1159 if restart:
1160 1160 raise NotImplementedError("Engine restart is not yet implemented")
1161 1161
1162 1162 block = self.block if block is None else block
1163 1163 if hub:
1164 1164 targets = 'all'
1165 targets = self._build_targets(targets)[0]
1165 try:
1166 targets = self._build_targets(targets)[0]
1167 except NoEnginesRegistered:
1168 targets = []
1166 1169 for t in targets:
1167 1170 self.session.send(self._control_socket, 'shutdown_request',
1168 1171 content={'restart':restart},ident=t)
1169 1172 error = False
1170 1173 if block or hub:
1171 1174 self._flush_ignored_control()
1172 1175 for i in range(len(targets)):
1173 1176 idents,msg = self.session.recv(self._control_socket, 0)
1174 1177 if self.debug:
1175 1178 pprint(msg)
1176 1179 if msg['content']['status'] != 'ok':
1177 1180 error = self._unwrap_exception(msg['content'])
1178 1181 else:
1179 1182 self._ignored_control_replies += len(targets)
1180 1183
1181 1184 if hub:
1182 1185 time.sleep(0.25)
1183 1186 self.session.send(self._query_socket, 'shutdown_request')
1184 1187 idents,msg = self.session.recv(self._query_socket, 0)
1185 1188 if self.debug:
1186 1189 pprint(msg)
1187 1190 if msg['content']['status'] != 'ok':
1188 1191 error = self._unwrap_exception(msg['content'])
1189 1192
1190 1193 if error:
1191 1194 raise error
1192 1195
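    # A sketch of what this backport enables: with zero engines registered,
    # _build_targets raises NoEnginesRegistered, which shutdown now catches,
    # so the Hub itself can still be terminated (illustrative):
    #
    #     rc = Client()
    #     rc.shutdown(hub=True, block=True)  # succeeds even with no engines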
1193 1196 #--------------------------------------------------------------------------
1194 1197 # Execution related methods
1195 1198 #--------------------------------------------------------------------------
1196 1199
1197 1200 def _maybe_raise(self, result):
1198 1201 """wrapper for maybe raising an exception if apply failed."""
1199 1202 if isinstance(result, error.RemoteError):
1200 1203 raise result
1201 1204
1202 1205 return result
1203 1206
1204 1207 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1205 1208 ident=None):
1206 1209 """construct and send an apply message via a socket.
1207 1210
1208 1211 This is the principal method with which all engine execution is performed by views.
1209 1212 """
1210 1213
1211 1214 if self._closed:
1212 1215 raise RuntimeError("Client cannot be used after its sockets have been closed")
1213 1216
1214 1217 # defaults:
1215 1218 args = args if args is not None else []
1216 1219 kwargs = kwargs if kwargs is not None else {}
1217 1220 subheader = subheader if subheader is not None else {}
1218 1221
1219 1222 # validate arguments
1220 1223 if not callable(f) and not isinstance(f, Reference):
1221 1224 raise TypeError("f must be callable, not %s"%type(f))
1222 1225 if not isinstance(args, (tuple, list)):
1223 1226 raise TypeError("args must be tuple or list, not %s"%type(args))
1224 1227 if not isinstance(kwargs, dict):
1225 1228 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1226 1229 if not isinstance(subheader, dict):
1227 1230 raise TypeError("subheader must be dict, not %s"%type(subheader))
1228 1231
1229 1232 bufs = util.pack_apply_message(f,args,kwargs)
1230 1233
1231 1234 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1232 1235 subheader=subheader, track=track)
1233 1236
1234 1237 msg_id = msg['header']['msg_id']
1235 1238 self.outstanding.add(msg_id)
1236 1239 if ident:
1237 1240 # possibly routed to a specific engine
1238 1241 if isinstance(ident, list):
1239 1242 ident = ident[-1]
1240 1243 if ident in self._engines.values():
1241 1244 # save for later, in case of engine death
1242 1245 self._outstanding_dict[ident].add(msg_id)
1243 1246 self.history.append(msg_id)
1244 1247 self.metadata[msg_id]['submitted'] = datetime.now()
1245 1248
1246 1249 return msg
1247 1250
1248 1251 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1249 1252 """construct and send an execute request via a socket.
1250 1253
1251 1254 """
1252 1255
1253 1256 if self._closed:
1254 1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1255 1258
1256 1259 # defaults:
1257 1260 subheader = subheader if subheader is not None else {}
1258 1261
1259 1262 # validate arguments
1260 1263 if not isinstance(code, basestring):
1261 1264 raise TypeError("code must be text, not %s" % type(code))
1262 1265 if not isinstance(subheader, dict):
1263 1266 raise TypeError("subheader must be dict, not %s" % type(subheader))
1264 1267
1265 1268 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1266 1269
1267 1270
1268 1271 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1269 1272 subheader=subheader)
1270 1273
1271 1274 msg_id = msg['header']['msg_id']
1272 1275 self.outstanding.add(msg_id)
1273 1276 if ident:
1274 1277 # possibly routed to a specific engine
1275 1278 if isinstance(ident, list):
1276 1279 ident = ident[-1]
1277 1280 if ident in self._engines.values():
1278 1281 # save for later, in case of engine death
1279 1282 self._outstanding_dict[ident].add(msg_id)
1280 1283 self.history.append(msg_id)
1281 1284 self.metadata[msg_id]['submitted'] = datetime.now()
1282 1285
1283 1286 return msg
1284 1287
1285 1288 #--------------------------------------------------------------------------
1286 1289 # construct a View object
1287 1290 #--------------------------------------------------------------------------
1288 1291
1289 1292 def load_balanced_view(self, targets=None):
1290 1293 """construct a LoadBalancedView object.
1291 1294
1292 1295 If no arguments are specified, create a LoadBalancedView
1293 1296 using all engines.
1294 1297
1295 1298 Parameters
1296 1299 ----------
1297 1300
1298 1301 targets: list,slice,int,etc. [default: use all engines]
1299 1302 The subset of engines across which to load-balance
1300 1303 """
1301 1304 if targets == 'all':
1302 1305 targets = None
1303 1306 if targets is not None:
1304 1307 targets = self._build_targets(targets)[1]
1305 1308 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1306 1309
1307 1310 def direct_view(self, targets='all'):
1308 1311 """construct a DirectView object.
1309 1312
1310 1313 If no targets are specified, create a DirectView using all engines.
1311 1314
1312 1315 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1313 1316 evaluate the target engines at each execution, whereas rc[:] will connect to
1314 1317 all *current* engines, and that list will not change.
1315 1318
1316 1319 That is, 'all' will always use all engines, whereas rc[:] will not use
1317 1320 engines added after the DirectView is constructed.
1318 1321
1319 1322 Parameters
1320 1323 ----------
1321 1324
1322 1325 targets: list,slice,int,etc. [default: use all engines]
1323 1326 The engines to use for the View
1324 1327 """
1325 1328 single = isinstance(targets, int)
1326 1329 # allow 'all' to be lazily evaluated at each execution
1327 1330 if targets != 'all':
1328 1331 targets = self._build_targets(targets)[1]
1329 1332 if single:
1330 1333 targets = targets[0]
1331 1334 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1332 1335
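    # A sketch of the lazy-'all' vs snapshot distinction (names illustrative):
    #
    #     dv_lazy = rc.direct_view('all')  # re-evaluates the engine list per execution
    #     dv_snap = rc[:]                  # pinned to the engines present right now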
1333 1336 #--------------------------------------------------------------------------
1334 1337 # Query methods
1335 1338 #--------------------------------------------------------------------------
1336 1339
1337 1340 @spin_first
1338 1341 def get_result(self, indices_or_msg_ids=None, block=None):
1339 1342 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1340 1343
1341 1344 If the client already has the results, no request to the Hub will be made.
1342 1345
1343 1346 This is a convenient way to construct AsyncResult objects, which are wrappers
1344 1347 that include metadata about execution, and allow for awaiting results that
1345 1348 were not submitted by this Client.
1346 1349
1347 1350 It can also be a convenient way to retrieve the metadata associated with
1348 1351 blocking execution, since it always retrieves the metadata as well as the result.
1349 1352
1350 1353 Examples
1351 1354 --------
1352 1355 ::
1353 1356
1354 1357 In [10]: ar = rc.get_result(msg_id)  # msg_id from a prior submission
1355 1358
1356 1359 Parameters
1357 1360 ----------
1358 1361
1359 1362 indices_or_msg_ids : integer history index, str msg_id, or list of either
1360 1363 The indices or msg_ids of the results to be retrieved
1361 1364
1362 1365 block : bool
1363 1366 Whether to wait for the result to be done
1364 1367
1365 1368 Returns
1366 1369 -------
1367 1370
1368 1371 AsyncResult
1369 1372 A single AsyncResult object will always be returned.
1370 1373
1371 1374 AsyncHubResult
1372 1375 A subclass of AsyncResult that retrieves results from the Hub
1373 1376
1374 1377 """
1375 1378 block = self.block if block is None else block
1376 1379 if indices_or_msg_ids is None:
1377 1380 indices_or_msg_ids = -1
1378 1381
1379 1382 if not isinstance(indices_or_msg_ids, (list,tuple)):
1380 1383 indices_or_msg_ids = [indices_or_msg_ids]
1381 1384
1382 1385 theids = []
1383 1386 for id in indices_or_msg_ids:
1384 1387 if isinstance(id, int):
1385 1388 id = self.history[id]
1386 1389 if not isinstance(id, basestring):
1387 1390 raise TypeError("indices must be str or int, not %r"%id)
1388 1391 theids.append(id)
1389 1392
1390 1393 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1391 1394 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1392 1395
1393 1396 if remote_ids:
1394 1397 ar = AsyncHubResult(self, msg_ids=theids)
1395 1398 else:
1396 1399 ar = AsyncResult(self, msg_ids=theids)
1397 1400
1398 1401 if block:
1399 1402 ar.wait()
1400 1403
1401 1404 return ar
1402 1405
1403 1406 @spin_first
1404 1407 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1405 1408 """Resubmit one or more tasks.
1406 1409
1407 1410 in-flight tasks may not be resubmitted.
1408 1411
1409 1412 Parameters
1410 1413 ----------
1411 1414
1412 1415 indices_or_msg_ids : integer history index, str msg_id, or list of either
1412 1415 The indices or msg_ids of the results to be retrieved
1414 1417
1415 1418 block : bool
1416 1419 Whether to wait for the result to be done
1417 1420
1418 1421 Returns
1419 1422 -------
1420 1423
1421 1424 AsyncHubResult
1422 1425 A subclass of AsyncResult that retrieves results from the Hub
1423 1426
1424 1427 """
1425 1428 block = self.block if block is None else block
1426 1429 if indices_or_msg_ids is None:
1427 1430 indices_or_msg_ids = -1
1428 1431
1429 1432 if not isinstance(indices_or_msg_ids, (list,tuple)):
1430 1433 indices_or_msg_ids = [indices_or_msg_ids]
1431 1434
1432 1435 theids = []
1433 1436 for id in indices_or_msg_ids:
1434 1437 if isinstance(id, int):
1435 1438 id = self.history[id]
1436 1439 if not isinstance(id, basestring):
1437 1440 raise TypeError("indices must be str or int, not %r"%id)
1438 1441 theids.append(id)
1439 1442
1440 1443 content = dict(msg_ids = theids)
1441 1444
1442 1445 self.session.send(self._query_socket, 'resubmit_request', content)
1443 1446
1444 1447 zmq.select([self._query_socket], [], [])
1445 1448 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1446 1449 if self.debug:
1447 1450 pprint(msg)
1448 1451 content = msg['content']
1449 1452 if content['status'] != 'ok':
1450 1453 raise self._unwrap_exception(content)
1451 1454 mapping = content['resubmitted']
1452 1455 new_ids = [ mapping[msg_id] for msg_id in theids ]
1453 1456
1454 1457 ar = AsyncHubResult(self, msg_ids=new_ids)
1455 1458
1456 1459 if block:
1457 1460 ar.wait()
1458 1461
1459 1462 return ar
1460 1463
1461 1464 @spin_first
1462 1465 def result_status(self, msg_ids, status_only=True):
1463 1466 """Check on the status of the result(s) of the apply request with `msg_ids`.
1464 1467
1465 1468 If status_only is False, then the actual results will be retrieved, else
1466 1469 only the status of the results will be checked.
1467 1470
1468 1471 Parameters
1469 1472 ----------
1470 1473
1471 1474 msg_ids : list of msg_ids
1472 1475 if int:
1473 1476 Passed as index to self.history for convenience.
1474 1477 status_only : bool (default: True)
1475 1478 if False:
1476 1479 Retrieve the actual results of completed tasks.
1477 1480
1478 1481 Returns
1479 1482 -------
1480 1483
1481 1484 results : dict
1482 1485 There will always be the keys 'pending' and 'completed', which will
1483 1486 be lists of msg_ids that are incomplete or complete. If `status_only`
1484 1487 is False, then completed results will be keyed by their `msg_id`.
1485 1488 """
1486 1489 if not isinstance(msg_ids, (list,tuple)):
1487 1490 msg_ids = [msg_ids]
1488 1491
1489 1492 theids = []
1490 1493 for msg_id in msg_ids:
1491 1494 if isinstance(msg_id, int):
1492 1495 msg_id = self.history[msg_id]
1493 1496 if not isinstance(msg_id, basestring):
1494 1497 raise TypeError("msg_ids must be str, not %r"%msg_id)
1495 1498 theids.append(msg_id)
1496 1499
1497 1500 completed = []
1498 1501 local_results = {}
1499 1502
1500 1503 # comment this block out to temporarily disable local shortcut:
1501 1504 for msg_id in theids:
1502 1505 if msg_id in self.results:
1503 1506 completed.append(msg_id)
1504 1507 local_results[msg_id] = self.results[msg_id]
1505 1508 theids.remove(msg_id)
1506 1509
1507 1510 if theids: # some not locally cached
1508 1511 content = dict(msg_ids=theids, status_only=status_only)
1509 1512 msg = self.session.send(self._query_socket, "result_request", content=content)
1510 1513 zmq.select([self._query_socket], [], [])
1511 1514 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1512 1515 if self.debug:
1513 1516 pprint(msg)
1514 1517 content = msg['content']
1515 1518 if content['status'] != 'ok':
1516 1519 raise self._unwrap_exception(content)
1517 1520 buffers = msg['buffers']
1518 1521 else:
1519 1522 content = dict(completed=[],pending=[])
1520 1523
1521 1524 content['completed'].extend(completed)
1522 1525
1523 1526 if status_only:
1524 1527 return content
1525 1528
1526 1529 failures = []
1527 1530 # load cached results into result:
1528 1531 content.update(local_results)
1529 1532
1530 1533 # update cache with results:
1531 1534 for msg_id in sorted(theids):
1532 1535 if msg_id in content['completed']:
1533 1536 rec = content[msg_id]
1534 1537 parent = rec['header']
1535 1538 header = rec['result_header']
1536 1539 rcontent = rec['result_content']
1537 1540 iodict = rec['io']
1538 1541 if isinstance(rcontent, str):
1539 1542 rcontent = self.session.unpack(rcontent)
1540 1543
1541 1544 md = self.metadata[msg_id]
1542 1545 md.update(self._extract_metadata(header, parent, rcontent))
1543 1546 if rec.get('received'):
1544 1547 md['received'] = rec['received']
1545 1548 md.update(iodict)
1546 1549
1547 1550 if rcontent['status'] == 'ok':
1548 1551 if header['msg_type'] == 'apply_reply':
1549 1552 res,buffers = util.unserialize_object(buffers)
1550 1553 elif header['msg_type'] == 'execute_reply':
1551 1554 res = ExecuteReply(msg_id, rcontent, md)
1552 1555 else:
1553 raise KeyError("unhandled msg type: %r" % header[msg_type])
1556 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1554 1557 else:
1555 1558 res = self._unwrap_exception(rcontent)
1556 1559 failures.append(res)
1557 1560
1558 1561 self.results[msg_id] = res
1559 1562 content[msg_id] = res
1560 1563
1561 1564 if len(theids) == 1 and failures:
1562 1565 raise failures[0]
1563 1566
1564 1567 error.collect_exceptions(failures, "result_status")
1565 1568 return content
1566 1569
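A minimal sketch of status polling, assuming a running cluster: `result_status` with `status_only=True` answers "which of these are done?" without pulling result payloads.

    from IPython.parallel import Client

    rc = Client()
    ar = rc.load_balanced_view().apply_async(sum, range(10))
    status = rc.result_status(ar.msg_ids, status_only=True)
    print(status['pending'])    # msg_ids still in flight
    print(status['completed'])  # msg_ids already finished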
1567 1570 @spin_first
1568 1571 def queue_status(self, targets='all', verbose=False):
1569 1572 """Fetch the status of engine queues.
1570 1573
1571 1574 Parameters
1572 1575 ----------
1573 1576
1574 1577 targets : int/str/list of ints/strs
1575 1578 the engines whose states are to be queried.
1576 1579 default : all
1577 1580 verbose : bool
1578 1581 If True, return the lists of msg_ids in each queue instead of just lengths
1579 1582 """
1580 1583 if targets == 'all':
1581 1584 # let the Hub interpret 'all', rather than expanding it locally
1582 1585 engine_ids = None
1583 1586 else:
1584 1587 engine_ids = self._build_targets(targets)[1]
1585 1588 content = dict(targets=engine_ids, verbose=verbose)
1586 1589 self.session.send(self._query_socket, "queue_request", content=content)
1587 1590 idents,msg = self.session.recv(self._query_socket, 0)
1588 1591 if self.debug:
1589 1592 pprint(msg)
1590 1593 content = msg['content']
1591 1594 status = content.pop('status')
1592 1595 if status != 'ok':
1593 1596 raise self._unwrap_exception(content)
1594 1597 content = rekey(content)
1595 1598 if isinstance(targets, int):
1596 1599 return content[targets]
1597 1600 else:
1598 1601 return content
1599 1602
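A sketch of inspecting queue depths, assuming a running cluster; the exact keys in the summary (engine ids, plus an 'unassigned' count for the task scheduler's backlog) follow the Hub's reply and should be treated as indicative.

    from IPython.parallel import Client

    rc = Client()
    # summary counts for every engine
    print(rc.queue_status())
    # full msg_id lists for engine 0 only
    print(rc.queue_status(targets=0, verbose=True))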
1600 1603 @spin_first
1601 1604 def purge_results(self, jobs=[], targets=[]):
1602 1605 """Tell the Hub to forget results.
1603 1606
1604 1607 Individual results can be purged by msg_id, or the entire
1605 1608 history of specific targets can be purged.
1606 1609
1607 1610 Use `purge_results('all')` to scrub everything from the Hub's db.
1608 1611
1609 1612 Parameters
1610 1613 ----------
1611 1614
1612 1615 jobs : str or list of str or AsyncResult objects
1613 1616 the msg_ids whose results should be forgotten.
1614 1617 targets : int/str/list of ints/strs
1615 1618 The engines, by int_id, whose entire history is to be purged.
1616 1619
1617 1620 default : None
1618 1621 """
1619 1622 if not targets and not jobs:
1620 1623 raise ValueError("Must specify at least one of `targets` and `jobs`")
1621 1624 if targets:
1622 1625 targets = self._build_targets(targets)[1]
1623 1626
1624 1627 # construct msg_ids from jobs
1625 1628 if jobs == 'all':
1626 1629 msg_ids = jobs
1627 1630 else:
1628 1631 msg_ids = []
1629 1632 if isinstance(jobs, (basestring,AsyncResult)):
1630 1633 jobs = [jobs]
1631 1634 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1632 1635 if bad_ids:
1633 1636 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1634 1637 for j in jobs:
1635 1638 if isinstance(j, AsyncResult):
1636 1639 msg_ids.extend(j.msg_ids)
1637 1640 else:
1638 1641 msg_ids.append(j)
1639 1642
1640 1643 content = dict(engine_ids=targets, msg_ids=msg_ids)
1641 1644 self.session.send(self._query_socket, "purge_request", content=content)
1642 1645 idents, msg = self.session.recv(self._query_socket, 0)
1643 1646 if self.debug:
1644 1647 pprint(msg)
1645 1648 content = msg['content']
1646 1649 if content['status'] != 'ok':
1647 1650 raise self._unwrap_exception(content)
1648 1651
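A sketch of freeing Hub-side storage, assuming a running cluster; note that purging only affects the Hub's database, not results already cached in this Client.

    from IPython.parallel import Client

    rc = Client()
    ar = rc.load_balanced_view().apply_async(max, [1, 2, 3])
    ar.get()
    rc.purge_results(jobs=ar)    # forget just these results
    rc.purge_results(targets=0)  # drop engine 0's entire history
    rc.purge_results('all')      # scrub everything from the Hub's db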
1649 1652 @spin_first
1650 1653 def hub_history(self):
1651 1654 """Get the Hub's history
1652 1655
1653 1656 Just like the Client, the Hub has a history, which is a list of msg_ids.
1654 1657 This will contain the history of all clients, and, depending on configuration,
1655 1658 may contain history across multiple cluster sessions.
1656 1659
1657 1660 Any msg_id returned here is a valid argument to `get_result`.
1658 1661
1659 1662 Returns
1660 1663 -------
1661 1664
1662 1665 msg_ids : list of strs
1663 1666 list of all msg_ids, ordered by task submission time.
1664 1667 """
1665 1668
1666 1669 self.session.send(self._query_socket, "history_request", content={})
1667 1670 idents, msg = self.session.recv(self._query_socket, 0)
1668 1671
1669 1672 if self.debug:
1670 1673 pprint(msg)
1671 1674 content = msg['content']
1672 1675 if content['status'] != 'ok':
1673 1676 raise self._unwrap_exception(content)
1674 1677 else:
1675 1678 return content['history']
1676 1679
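A sketch pairing `hub_history` with `get_result`, assuming a running cluster whose Hub has recorded some tasks:

    from IPython.parallel import Client

    rc = Client()
    msg_ids = rc.hub_history()
    if msg_ids:
        # any msg_id from hub_history is a valid argument to get_result
        ar = rc.get_result(msg_ids[-1])
        print(ar.get())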
1677 1680 @spin_first
1678 1681 def db_query(self, query, keys=None):
1679 1682 """Query the Hub's TaskRecord database
1680 1683
1681 1684 This will return a list of task record dicts that match `query`
1682 1685
1683 1686 Parameters
1684 1687 ----------
1685 1688
1686 1689 query : mongodb query dict
1687 1690 The search dict. See mongodb query docs for details.
1688 1691 keys : list of strs [optional]
1689 1692 The subset of keys to be returned. The default is to fetch everything but buffers.
1690 1693 'msg_id' will *always* be included.
1691 1694 """
1692 1695 if isinstance(keys, basestring):
1693 1696 keys = [keys]
1694 1697 content = dict(query=query, keys=keys)
1695 1698 self.session.send(self._query_socket, "db_request", content=content)
1696 1699 idents, msg = self.session.recv(self._query_socket, 0)
1697 1700 if self.debug:
1698 1701 pprint(msg)
1699 1702 content = msg['content']
1700 1703 if content['status'] != 'ok':
1701 1704 raise self._unwrap_exception(content)
1702 1705
1703 1706 records = content['records']
1704 1707
1705 1708 buffer_lens = content['buffer_lens']
1706 1709 result_buffer_lens = content['result_buffer_lens']
1707 1710 buffers = msg['buffers']
1708 1711 has_bufs = buffer_lens is not None
1709 1712 has_rbufs = result_buffer_lens is not None
1710 1713 for i,rec in enumerate(records):
1711 1714 # relink buffers
1712 1715 if has_bufs:
1713 1716 blen = buffer_lens[i]
1714 1717 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1715 1718 if has_rbufs:
1716 1719 blen = result_buffer_lens[i]
1717 1720 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1718 1721
1719 1722 return records
1720 1723
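A sketch of a TaskRecord query, assuming a running cluster; the '$gt' operator follows the MongoDB-style query support described above, and the 'submitted'/'completed' record keys are assumptions about the Hub's TaskRecord schema.

    from datetime import datetime, timedelta

    from IPython.parallel import Client

    rc = Client()
    since = datetime.now() - timedelta(hours=1)
    # tasks submitted in the last hour; 'msg_id' is always included
    records = rc.db_query({'submitted': {'$gt': since}},
                          keys=['submitted', 'completed'])
    for rec in records:
        print("%s %s" % (rec['msg_id'], rec.get('completed')))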
1721 1724 __all__ = [ 'Client' ]