##// END OF EJS Templates
Backport PR #2773: Fixed minor typo causing AttributeError to be thrown....
MinRK -
Show More
@@ -1,1724 +1,1724 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52
53 53 from .asyncresult import AsyncResult, AsyncHubResult
54 54 from .view import DirectView, LoadBalancedView
55 55
# Python 3 compatibility shim: alias xrange to range at module level,
# since isinstance(..., xrange) checks below assume the py2 name exists.
if sys.version_info[0] >= 3:
    # xrange is used in a couple 'isinstance' tests in py2
    # should be just 'range' in 3k
    xrange = range
60 60
61 61 #--------------------------------------------------------------------------
62 62 # Decorators for Client methods
63 63 #--------------------------------------------------------------------------
64 64
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Applied to Client methods that read result/registration state, so
    pending ZMQ messages are flushed before the method runs.
    """
    self.spin()
    return f(self, *args, **kwargs)
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Classes
74 74 #--------------------------------------------------------------------------
75 75
76 76
class ExecuteReply(object):
    """wrapper for finished Execute results

    Provides dict-style access to the reply metadata, and the IPython
    ``_repr_*_`` display hooks backed by the captured pyout data.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    def _display_data(self, mime, default=None):
        """Return the pyout display-data value for *mime*, or *default*.

        Centralizes the ``metadata['pyout'] or {'data':{}}`` lookup that
        was previously duplicated across every _repr_*_ method.
        """
        pyout = self.metadata['pyout'] or {'data':{}}
        return pyout['data'].get(mime, default)

    def __getitem__(self, key):
        # item access is an alias for metadata access
        return self.metadata[key]

    def __getattr__(self, key):
        # attribute access falls back to metadata keys
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        text_out = self._display_data('text/plain', '')
        if len(text_out) > 32:
            # truncate long reprs to keep one-line display readable
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        text_out = self._display_data('text/plain', '')

        if not text_out:
            return

        try:
            ip = get_ipython()
        except NameError:
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )

    def _repr_html_(self):
        return self._display_data("text/html")

    def _repr_latex_(self):
        return self._display_data("text/latex")

    def _repr_json_(self):
        return self._display_data("application/json")

    def _repr_javascript_(self):
        return self._display_data("application/javascript")

    def _repr_png_(self):
        return self._display_data("image/png")

    def _repr_jpeg_(self):
        return self._display_data("image/jpeg")

    def _repr_svg_(self):
        return self._display_data("image/svg+xml")
158 158
159 159
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed set of allowed keys, with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
              'outputs_ready' : False,
        }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # membership test (`key in self`) replaces py2-only iterkeys();
        # identical semantics, works on both Python 2 and 3, and avoids
        # building an iterator just to test membership
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
212 212
213 213
214 214 class Client(HasTraits):
215 215 """A semi-synchronous client to the IPython ZMQ cluster
216 216
217 217 Parameters
218 218 ----------
219 219
220 220 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
221 221 Connection information for the Hub's registration. If a json connector
222 222 file is given, then likely no further configuration is necessary.
223 223 [Default: use profile]
224 224 profile : bytes
225 225 The name of the Cluster profile to be used to find connector information.
226 226 If run from an IPython application, the default profile will be the same
227 227 as the running application, otherwise it will be 'default'.
228 228 context : zmq.Context
229 229 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 230 debug : bool
231 231 flag for lots of message printing for debug purposes
232 232 timeout : int/float
233 233 time (in seconds) to wait for connection replies from the Hub
234 234 [Default: 10]
235 235
236 236 #-------------- session related args ----------------
237 237
238 238 config : Config object
239 239 If specified, this will be relayed to the Session for configuration
240 240 username : str
241 241 set username for the session object
242 242 packer : str (import_string) or callable
243 243 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
244 244 function to serialize messages. Must support same input as
245 245 JSON, and output must be bytes.
246 246 You can pass a callable directly as `pack`
247 247 unpacker : str (import_string) or callable
248 248 The inverse of packer. Only necessary if packer is specified as *not* one
249 249 of 'json' or 'pickle'.
250 250
251 251 #-------------- ssh related args ----------------
252 252 # These are args for configuring the ssh tunnel to be used
253 253 # credentials are used to forward connections over ssh to the Controller
254 254 # Note that the ip given in `addr` needs to be relative to sshserver
255 255 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
256 256 # and set sshserver as the same machine the Controller is on. However,
257 257 # the only requirement is that sshserver is able to see the Controller
258 258 # (i.e. is within the same trusted network).
259 259
260 260 sshserver : str
261 261 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
262 262 If keyfile or password is specified, and this is not, it will default to
263 263 the ip given in addr.
264 264 sshkey : str; path to ssh private key file
265 265 This specifies a key to be used in ssh login, default None.
266 266 Regular default ssh keys will be used without specifying this argument.
267 267 password : str
268 268 Your ssh password to sshserver. Note that if this is left None,
269 269 you will be prompted for it if passwordless key based login is unavailable.
270 270 paramiko : bool
271 271 flag for whether to use paramiko instead of shell ssh for tunneling.
272 272 [default: True on win32, False else]
273 273
274 274 ------- exec authentication args -------
275 275 If even localhost is untrusted, you can have some protection against
276 276 unauthorized execution by signing messages with HMAC digests.
277 277 Messages are still sent as cleartext, so if someone can snoop your
278 278 loopback traffic this will not protect your privacy, but will prevent
279 279 unauthorized execution.
280 280
281 281 exec_key : str
282 282 an authentication key or file containing a key
283 283 default: None
284 284
285 285
286 286 Attributes
287 287 ----------
288 288
289 289 ids : list of int engine IDs
290 290 requesting the ids attribute always synchronizes
291 291 the registration state. To request ids without synchronization,
292 292 use semi-private _ids attributes.
293 293
294 294 history : list of msg_ids
295 295 a list of msg_ids, keeping track of all the execution
296 296 messages you have submitted in order.
297 297
298 298 outstanding : set of msg_ids
299 299 a set of msg_ids that have been submitted, but whose
300 300 results have not yet been received.
301 301
302 302 results : dict
303 303 a dict of all our results, keyed by msg_id
304 304
305 305 block : bool
306 306 determines default behavior when block not specified
307 307 in execution methods
308 308
309 309 Methods
310 310 -------
311 311
312 312 spin
313 313 flushes incoming results and registration state changes
314 314 control methods spin, and requesting `ids` also ensures up to date
315 315
316 316 wait
317 317 wait on one or more msg_ids
318 318
319 319 execution methods
320 320 apply
321 321 legacy: execute, run
322 322
323 323 data movement
324 324 push, pull, scatter, gather
325 325
326 326 query methods
327 327 queue_status, get_result, purge, result_status
328 328
329 329 control methods
330 330 abort, shutdown
331 331
332 332 """
333 333
334 334
335 335 block = Bool(False)
336 336 outstanding = Set()
337 337 results = Instance('collections.defaultdict', (dict,))
338 338 metadata = Instance('collections.defaultdict', (Metadata,))
339 339 history = List()
340 340 debug = Bool(False)
341 341 _spin_thread = Any()
342 342 _stop_spinning = Any()
343 343
344 344 profile=Unicode()
345 345 def _profile_default(self):
346 346 if BaseIPythonApplication.initialized():
347 347 # an IPython app *might* be running, try to get its profile
348 348 try:
349 349 return BaseIPythonApplication.instance().profile
350 350 except (AttributeError, MultipleInstanceError):
351 351 # could be a *different* subclass of config.Application,
352 352 # which would raise one of these two errors.
353 353 return u'default'
354 354 else:
355 355 return u'default'
356 356
357 357
358 358 _outstanding_dict = Instance('collections.defaultdict', (set,))
359 359 _ids = List()
360 360 _connected=Bool(False)
361 361 _ssh=Bool(False)
362 362 _context = Instance('zmq.Context')
363 363 _config = Dict()
364 364 _engines=Instance(util.ReverseDict, (), {})
365 365 # _hub_socket=Instance('zmq.Socket')
366 366 _query_socket=Instance('zmq.Socket')
367 367 _control_socket=Instance('zmq.Socket')
368 368 _iopub_socket=Instance('zmq.Socket')
369 369 _notification_socket=Instance('zmq.Socket')
370 370 _mux_socket=Instance('zmq.Socket')
371 371 _task_socket=Instance('zmq.Socket')
372 372 _task_scheme=Unicode()
373 373 _closed = False
374 374 _ignored_control_replies=Integer(0)
375 375 _ignored_hub_replies=Integer(0)
376 376
    def __new__(self, *args, **kw):
        # don't raise on positional args
        # (only keywords are forwarded; positional args such as url_or_file
        # are consumed by our __init__ instead — TODO confirm HasTraits.__new__
        # rejects positionals in this traitlets version)
        return HasTraits.__new__(self, **kw)
380 380
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Connect to the Hub.

        See the class docstring for the meaning of each parameter.
        Connection info resolution order: explicit url_or_file, then the
        profile's ipcontroller-client.json.  Remaining ``extra_args`` are
        relayed to the Session constructor (packer/unpacker/username/key).
        """
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            # share the process-global zmq context unless one was given
            context = zmq.Context.instance()
        self._context = context
        self._stop_spinning = Event()

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                # default to the profile's connection file
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        if url_or_file is None:
            raise ValueError(
                "I can't find enough information to connect to a hub!"
                " Please specify at least one of url_or_file or profile."
            )

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                if self._cd:
                    # allow bare filenames relative to the security dir
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                if not os.path.exists(url_or_file):
                    raise IOError("Connection file not found: %r" % url_or_file)
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit arguments take precedence over the connection file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # key may be given as a path to a keyfile...
                extra_args['keyfile'] = exec_key
            else:
                # ...or as the literal key value
                exec_key = cast_bytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages, consumed by the _flush_* methods
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

        # last step: setup magics, if we are in IPython:

        try:
            ip = get_ipython()
        except NameError:
            # not running inside IPython; nothing more to do
            return
        else:
            if 'px' not in ip.magics_manager.magics:
                # in IPython but we are the first Client.
                # activate a default view for parallel magics.
                self.activate()
496 496
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the zmq Context may be the shared process-global instance,
        # so only close our own sockets here; never terminate the context
        self.close()
500 500
501 501 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
502 502 if ipython_dir is None:
503 503 ipython_dir = get_ipython_dir()
504 504 if profile_dir is not None:
505 505 try:
506 506 self._cd = ProfileDir.find_profile_dir(profile_dir)
507 507 return
508 508 except ProfileDirError:
509 509 pass
510 510 elif profile is not None:
511 511 try:
512 512 self._cd = ProfileDir.find_profile_dir_by_name(
513 513 ipython_dir, profile)
514 514 return
515 515 except ProfileDirError:
516 516 pass
517 517 self._cd = None
518 518
519 519 def _update_engines(self, engines):
520 520 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
521 521 for k,v in engines.iteritems():
522 522 eid = int(k)
523 523 self._engines[eid] = v
524 524 self._ids.append(eid)
525 525 self._ids = sorted(self._ids)
526 526 if sorted(self._engines.keys()) != range(len(self._engines)) and \
527 527 self._task_scheme == 'pure' and self._task_socket:
528 528 self._stop_scheduling_tasks()
529 529
530 530 def _stop_scheduling_tasks(self):
531 531 """Stop scheduling tasks because an engine has been unregistered
532 532 from a pure ZMQ scheduler.
533 533 """
534 534 self._task_socket.close()
535 535 self._task_socket = None
536 536 msg = "An engine has been unregistered, and we are using pure " +\
537 537 "ZMQ task scheduling. Task farming will be disabled."
538 538 if self.outstanding:
539 539 msg += " If you were running tasks when this happened, " +\
540 540 "some `outstanding` msg_ids may never resolve."
541 541 warnings.warn(msg, RuntimeWarning)
542 542
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Accepts None (all engines), 'all', an int (negative counts from the
        end), a slice, or a collection of ints.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            # (reading self.ids spins, which may pick up fresh registrations)
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):  # py2-only basestring
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index counts from the end, like list indexing
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # resolve the slice against the current id list
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        # uuids as bytes (socket identities), plus the plain int ids
        return [cast_bytes(self._engines[t]) for t in targets], list(targets)
575 575
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, then
        connects mux/task/notification/control/iopub sockets from the
        addresses in the registration reply.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve wildcard/localhost addresses relative to the Hub's location
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        # stash the raw registration reply for later inspection
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # all DEALER sockets share our session identity
            ident = self.session.bsession
            if content.mux:
                self._mux_socket = self._context.socket(zmq.DEALER)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.DEALER)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # SUB socket subscribed to everything
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            if content.control:
                self._control_socket = self._context.socket(zmq.DEALER)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
638 638
639 639 #--------------------------------------------------------------------------
640 640 # handlers and callbacks for incoming messages
641 641 #--------------------------------------------------------------------------
642 642
643 643 def _unwrap_exception(self, content):
644 644 """unwrap exception, and remap engine_id to int."""
645 645 e = error.unwrap_exception(content)
646 646 # print e.traceback
647 647 if e.engine_info:
648 648 e_uuid = e.engine_info['engine_uuid']
649 649 eid = self._engines[e_uuid]
650 650 e.engine_info['engine_id'] = eid
651 651 return e
652 652
653 653 def _extract_metadata(self, header, parent, content):
654 654 md = {'msg_id' : parent['msg_id'],
655 655 'received' : datetime.now(),
656 656 'engine_uuid' : header.get('engine', None),
657 657 'follow' : parent.get('follow', []),
658 658 'after' : parent.get('after', []),
659 659 'status' : content['status'],
660 660 }
661 661
662 662 if md['engine_uuid'] is not None:
663 663 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
664 664
665 665 if 'date' in parent:
666 666 md['submitted'] = parent['date']
667 667 if 'started' in header:
668 668 md['started'] = header['started']
669 669 if 'date' in header:
670 670 md['completed'] = header['date']
671 671 return md
672 672
673 673 def _register_engine(self, msg):
674 674 """Register a new engine, and update our connection info."""
675 675 content = msg['content']
676 676 eid = content['id']
677 677 d = {eid : content['queue']}
678 678 self._update_engines(d)
679 679
680 680 def _unregister_engine(self, msg):
681 681 """Unregister an engine that has died."""
682 682 content = msg['content']
683 683 eid = int(content['id'])
684 684 if eid in self._ids:
685 685 self._ids.remove(eid)
686 686 uuid = self._engines.pop(eid)
687 687
688 688 self._handle_stranded_msgs(eid, uuid)
689 689
690 690 if self._task_socket and self._task_scheme == 'pure':
691 691 self._stop_scheduling_tasks()
692 692
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result; nothing was stranded after all
                continue
            try:
                # raise-and-catch so wrap_exception captures a real traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            # route through the normal reply handler so bookkeeping is shared
            self._handle_apply_reply(msg)
719 719
720 720 def _handle_execute_reply(self, msg):
721 721 """Save the reply to an execute_request into our results.
722 722
723 723 execute messages are never actually used. apply is used instead.
724 724 """
725 725
726 726 parent = msg['parent_header']
727 727 msg_id = parent['msg_id']
728 728 if msg_id not in self.outstanding:
729 729 if msg_id in self.history:
730 730 print ("got stale result: %s"%msg_id)
731 731 else:
732 732 print ("got unknown result: %s"%msg_id)
733 733 else:
734 734 self.outstanding.remove(msg_id)
735 735
736 736 content = msg['content']
737 737 header = msg['header']
738 738
739 739 # construct metadata:
740 740 md = self.metadata[msg_id]
741 741 md.update(self._extract_metadata(header, parent, content))
742 742 # is this redundant?
743 743 self.metadata[msg_id] = md
744 744
745 745 e_outstanding = self._outstanding_dict[md['engine_uuid']]
746 746 if msg_id in e_outstanding:
747 747 e_outstanding.remove(msg_id)
748 748
749 749 # construct result:
750 750 if content['status'] == 'ok':
751 751 self.results[msg_id] = ExecuteReply(msg_id, content, md)
752 752 elif content['status'] == 'aborted':
753 753 self.results[msg_id] = error.TaskAborted(msg_id)
754 754 elif content['status'] == 'resubmitted':
755 755 # TODO: handle resubmission
756 756 pass
757 757 else:
758 758 self.results[msg_id] = self._unwrap_exception(content)
759 759
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                # py2 print statements (this module predates py3 support)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # clear per-engine bookkeeping for this message
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            # apply results arrive serialized in the message buffers
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
796 796
797 797 def _flush_notifications(self):
798 798 """Flush notifications of engine registrations waiting
799 799 in ZMQ queue."""
800 800 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
801 801 while msg is not None:
802 802 if self.debug:
803 803 pprint(msg)
804 804 msg_type = msg['header']['msg_type']
805 805 handler = self._notification_handlers.get(msg_type, None)
806 806 if handler is None:
807 raise Exception("Unhandled message type: %s"%msg.msg_type)
807 raise Exception("Unhandled message type: %s" % msg_type)
808 808 else:
809 809 handler(msg)
810 810 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
811 811
812 812 def _flush_results(self, sock):
813 813 """Flush task or queue results waiting in ZMQ queue."""
814 814 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
815 815 while msg is not None:
816 816 if self.debug:
817 817 pprint(msg)
818 818 msg_type = msg['header']['msg_type']
819 819 handler = self._queue_handlers.get(msg_type, None)
820 820 if handler is None:
821 raise Exception("Unhandled message type: %s"%msg.msg_type)
821 raise Exception("Unhandled message type: %s" % msg_type)
822 822 else:
823 823 handler(msg)
824 824 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
825 825
826 826 def _flush_control(self, sock):
827 827 """Flush replies from the control channel waiting
828 828 in the ZMQ queue.
829 829
830 830 Currently: ignore them."""
831 831 if self._ignored_control_replies <= 0:
832 832 return
833 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 834 while msg is not None:
835 835 self._ignored_control_replies -= 1
836 836 if self.debug:
837 837 pprint(msg)
838 838 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
839 839
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv: exactly one reply is expected per ignored counter
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
845 845
846 846 def _flush_ignored_hub_replies(self):
847 847 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
848 848 while msg is not None:
849 849 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
850 850
851 851 def _flush_iopub(self, sock):
852 852 """Flush replies from the iopub channel waiting
853 853 in the ZMQ queue.
854 854 """
855 855 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
856 856 while msg is not None:
857 857 if self.debug:
858 858 pprint(msg)
859 859 parent = msg['parent_header']
860 860 # ignore IOPub messages with no parent.
861 861 # Caused by print statements or warnings from before the first execution.
862 862 if not parent:
863 863 continue
864 864 msg_id = parent['msg_id']
865 865 content = msg['content']
866 866 header = msg['header']
867 867 msg_type = msg['header']['msg_type']
868 868
869 869 # init metadata:
870 870 md = self.metadata[msg_id]
871 871
872 872 if msg_type == 'stream':
873 873 name = content['name']
874 874 s = md[name] or ''
875 875 md[name] = s + content['data']
876 876 elif msg_type == 'pyerr':
877 877 md.update({'pyerr' : self._unwrap_exception(content)})
878 878 elif msg_type == 'pyin':
879 879 md.update({'pyin' : content['code']})
880 880 elif msg_type == 'display_data':
881 881 md['outputs'].append(content)
882 882 elif msg_type == 'pyout':
883 883 md['pyout'] = content
884 884 elif msg_type == 'status':
885 885 # idle message comes after all outputs
886 886 if content['execution_state'] == 'idle':
887 887 md['outputs_ready'] = True
888 888 else:
889 889 # unhandled msg_type (status, etc.)
890 890 pass
891 891
892 892 # reduntant?
893 893 self.metadata[msg_id] = md
894 894
895 895 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
896 896
897 897 #--------------------------------------------------------------------------
898 898 # len, getitem
899 899 #--------------------------------------------------------------------------
900 900
    def __len__(self):
        """len(client) returns # of engines.

        Uses the `ids` property, so pending registration notifications
        are flushed first and the count reflects current engines.
        """
        return len(self.ids)
904 904
905 905 def __getitem__(self, key):
906 906 """index access returns DirectView multiplexer objects
907 907
908 908 Must be int, slice, or list/tuple/xrange of ints"""
909 909 if not isinstance(key, (int, slice, tuple, list, xrange)):
910 910 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
911 911 else:
912 912 return self.direct_view(key)
913 913
914 914 #--------------------------------------------------------------------------
915 915 # Begin public methods
916 916 #--------------------------------------------------------------------------
917 917
    @property
    def ids(self):
        """Always up-to-date ids property.

        Flushes pending registration/unregistration notifications before
        returning, so the list reflects currently-registered engines.
        """
        self._flush_notifications()
        # always copy: callers must not mutate our internal id list
        return list(self._ids)
924 924
925 925 def activate(self, targets='all', suffix=''):
926 926 """Create a DirectView and register it with IPython magics
927 927
928 928 Defines the magics `%px, %autopx, %pxresult, %%px`
929 929
930 930 Parameters
931 931 ----------
932 932
933 933 targets: int, list of ints, or 'all'
934 934 The engines on which the view's magics will run
935 935 suffix: str [default: '']
936 936 The suffix, if any, for the magics. This allows you to have
937 937 multiple views associated with parallel magics at the same time.
938 938
939 939 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
940 940 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
941 941 on engine 0.
942 942 """
943 943 view = self.direct_view(targets)
944 944 view.block = True
945 945 view.activate(suffix)
946 946 return view
947 947
948 948 def close(self):
949 949 if self._closed:
950 950 return
951 951 self.stop_spin_thread()
952 952 snames = filter(lambda n: n.endswith('socket'), dir(self))
953 953 for socket in map(lambda name: getattr(self, name), snames):
954 954 if isinstance(socket, zmq.Socket) and not socket.closed:
955 955 socket.close()
956 956 self._closed = True
957 957
958 958 def _spin_every(self, interval=1):
959 959 """target func for use in spin_thread"""
960 960 while True:
961 961 if self._stop_spinning.is_set():
962 962 return
963 963 time.sleep(interval)
964 964 self.spin()
965 965
966 966 def spin_thread(self, interval=1):
967 967 """call Client.spin() in a background thread on some regular interval
968 968
969 969 This helps ensure that messages don't pile up too much in the zmq queue
970 970 while you are working on other things, or just leaving an idle terminal.
971 971
972 972 It also helps limit potential padding of the `received` timestamp
973 973 on AsyncResult objects, used for timings.
974 974
975 975 Parameters
976 976 ----------
977 977
978 978 interval : float, optional
979 979 The interval on which to spin the client in the background thread
980 980 (simply passed to time.sleep).
981 981
982 982 Notes
983 983 -----
984 984
985 985 For precision timing, you may want to use this method to put a bound
986 986 on the jitter (in seconds) in `received` timestamps used
987 987 in AsyncResult.wall_time.
988 988
989 989 """
990 990 if self._spin_thread is not None:
991 991 self.stop_spin_thread()
992 992 self._stop_spinning.clear()
993 993 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
994 994 self._spin_thread.daemon = True
995 995 self._spin_thread.start()
996 996
997 997 def stop_spin_thread(self):
998 998 """stop background spin_thread, if any"""
999 999 if self._spin_thread is not None:
1000 1000 self._stop_spinning.set()
1001 1001 self._spin_thread.join()
1002 1002 self._spin_thread = None
1003 1003
1004 1004 def spin(self):
1005 1005 """Flush any registration notifications and execution results
1006 1006 waiting in the ZMQ queue.
1007 1007 """
1008 1008 if self._notification_socket:
1009 1009 self._flush_notifications()
1010 1010 if self._iopub_socket:
1011 1011 self._flush_iopub(self._iopub_socket)
1012 1012 if self._mux_socket:
1013 1013 self._flush_results(self._mux_socket)
1014 1014 if self._task_socket:
1015 1015 self._flush_results(self._task_socket)
1016 1016 if self._control_socket:
1017 1017 self._flush_control(self._control_socket)
1018 1018 if self._query_socket:
1019 1019 self._flush_ignored_hub_replies()
1020 1020
1021 1021 def wait(self, jobs=None, timeout=-1):
1022 1022 """waits on one or more `jobs`, for up to `timeout` seconds.
1023 1023
1024 1024 Parameters
1025 1025 ----------
1026 1026
1027 1027 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1028 1028 ints are indices to self.history
1029 1029 strs are msg_ids
1030 1030 default: wait on all outstanding messages
1031 1031 timeout : float
1032 1032 a time in seconds, after which to give up.
1033 1033 default is -1, which means no timeout
1034 1034
1035 1035 Returns
1036 1036 -------
1037 1037
1038 1038 True : when all msg_ids are done
1039 1039 False : timeout reached, some msg_ids still outstanding
1040 1040 """
1041 1041 tic = time.time()
1042 1042 if jobs is None:
1043 1043 theids = self.outstanding
1044 1044 else:
1045 1045 if isinstance(jobs, (int, basestring, AsyncResult)):
1046 1046 jobs = [jobs]
1047 1047 theids = set()
1048 1048 for job in jobs:
1049 1049 if isinstance(job, int):
1050 1050 # index access
1051 1051 job = self.history[job]
1052 1052 elif isinstance(job, AsyncResult):
1053 1053 map(theids.add, job.msg_ids)
1054 1054 continue
1055 1055 theids.add(job)
1056 1056 if not theids.intersection(self.outstanding):
1057 1057 return True
1058 1058 self.spin()
1059 1059 while theids.intersection(self.outstanding):
1060 1060 if timeout >= 0 and ( time.time()-tic ) > timeout:
1061 1061 break
1062 1062 time.sleep(1e-3)
1063 1063 self.spin()
1064 1064 return len(theids.intersection(self.outstanding)) == 0
1065 1065
1066 1066 #--------------------------------------------------------------------------
1067 1067 # Control methods
1068 1068 #--------------------------------------------------------------------------
1069 1069
1070 1070 @spin_first
1071 1071 def clear(self, targets=None, block=None):
1072 1072 """Clear the namespace in target(s)."""
1073 1073 block = self.block if block is None else block
1074 1074 targets = self._build_targets(targets)[0]
1075 1075 for t in targets:
1076 1076 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1077 1077 error = False
1078 1078 if block:
1079 1079 self._flush_ignored_control()
1080 1080 for i in range(len(targets)):
1081 1081 idents,msg = self.session.recv(self._control_socket,0)
1082 1082 if self.debug:
1083 1083 pprint(msg)
1084 1084 if msg['content']['status'] != 'ok':
1085 1085 error = self._unwrap_exception(msg['content'])
1086 1086 else:
1087 1087 self._ignored_control_replies += len(targets)
1088 1088 if error:
1089 1089 raise error
1090 1090
1091 1091
1092 1092 @spin_first
1093 1093 def abort(self, jobs=None, targets=None, block=None):
1094 1094 """Abort specific jobs from the execution queues of target(s).
1095 1095
1096 1096 This is a mechanism to prevent jobs that have already been submitted
1097 1097 from executing.
1098 1098
1099 1099 Parameters
1100 1100 ----------
1101 1101
1102 1102 jobs : msg_id, list of msg_ids, or AsyncResult
1103 1103 The jobs to be aborted
1104 1104
1105 1105 If unspecified/None: abort all outstanding jobs.
1106 1106
1107 1107 """
1108 1108 block = self.block if block is None else block
1109 1109 jobs = jobs if jobs is not None else list(self.outstanding)
1110 1110 targets = self._build_targets(targets)[0]
1111 1111
1112 1112 msg_ids = []
1113 1113 if isinstance(jobs, (basestring,AsyncResult)):
1114 1114 jobs = [jobs]
1115 1115 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1116 1116 if bad_ids:
1117 1117 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1118 1118 for j in jobs:
1119 1119 if isinstance(j, AsyncResult):
1120 1120 msg_ids.extend(j.msg_ids)
1121 1121 else:
1122 1122 msg_ids.append(j)
1123 1123 content = dict(msg_ids=msg_ids)
1124 1124 for t in targets:
1125 1125 self.session.send(self._control_socket, 'abort_request',
1126 1126 content=content, ident=t)
1127 1127 error = False
1128 1128 if block:
1129 1129 self._flush_ignored_control()
1130 1130 for i in range(len(targets)):
1131 1131 idents,msg = self.session.recv(self._control_socket,0)
1132 1132 if self.debug:
1133 1133 pprint(msg)
1134 1134 if msg['content']['status'] != 'ok':
1135 1135 error = self._unwrap_exception(msg['content'])
1136 1136 else:
1137 1137 self._ignored_control_replies += len(targets)
1138 1138 if error:
1139 1139 raise error
1140 1140
    @spin_first
    def shutdown(self, targets='all', restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        Parameters
        ----------

        targets: list of ints or 'all' [default: all]
            Which engines to shutdown.
        hub: bool [default: False]
            Whether to include the Hub. hub=True implies targets='all'.
        block: bool [default: self.block]
            Whether to wait for clean shutdown replies or not.
        restart: bool [default: False]
            NOT IMPLEMENTED
            whether to restart engines after shutting them down.

        Raises
        ------
        The last remote error received, if any reply has a non-ok status.
        """
        # imported here rather than at module top — presumably to avoid an
        # import cycle; TODO confirm
        from IPython.parallel.error import NoEnginesRegistered
        if restart:
            raise NotImplementedError("Engine restart is not yet implemented")

        block = self.block if block is None else block
        # shutting down the hub implies shutting down every engine
        if hub:
            targets = 'all'
        try:
            targets = self._build_targets(targets)[0]
        except NoEnginesRegistered:
            # no engines registered: nothing to send engine shutdowns to
            targets = []
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        # when shutting down the hub we always collect engine replies first,
        # even if the caller did not ask to block
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            # brief pause before asking the hub to shut down — presumably to
            # let the engine shutdown requests get through first; TODO confirm
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
1195 1195
1196 1196 #--------------------------------------------------------------------------
1197 1197 # Execution related methods
1198 1198 #--------------------------------------------------------------------------
1199 1199
1200 1200 def _maybe_raise(self, result):
1201 1201 """wrapper for maybe raising an exception if apply failed."""
1202 1202 if isinstance(result, error.RemoteError):
1203 1203 raise result
1204 1204
1205 1205 return result
1206 1206
    def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.

        Parameters
        ----------
        socket : the zmq socket on which to send the request
        f : callable (or Reference) to be applied on the engine
        args : list/tuple of positional arguments (default: [])
        kwargs : dict of keyword arguments (default: {})
        subheader : dict of extra header fields (default: {})
        track : passed through to session.send
        ident : engine identity (or list of identities) for routing

        Returns
        -------
        The full message dict returned by session.send; its msg_id is
        registered in `outstanding`, `history`, and `metadata`.

        Raises
        ------
        RuntimeError if the client has been closed; TypeError on invalid
        argument types.
        """

        if self._closed:
            raise RuntimeError("Client cannot be used after its sockets have been closed")

        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f) and not isinstance(f, Reference):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s"%type(subheader))

        # serialize f/args/kwargs into message buffers
        bufs = util.pack_apply_message(f,args,kwargs)

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            subheader=subheader, track=track)

        msg_id = msg['header']['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1250 1250
1251 1251 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1252 1252 """construct and send an execute request via a socket.
1253 1253
1254 1254 """
1255 1255
1256 1256 if self._closed:
1257 1257 raise RuntimeError("Client cannot be used after its sockets have been closed")
1258 1258
1259 1259 # defaults:
1260 1260 subheader = subheader if subheader is not None else {}
1261 1261
1262 1262 # validate arguments
1263 1263 if not isinstance(code, basestring):
1264 1264 raise TypeError("code must be text, not %s" % type(code))
1265 1265 if not isinstance(subheader, dict):
1266 1266 raise TypeError("subheader must be dict, not %s" % type(subheader))
1267 1267
1268 1268 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1269 1269
1270 1270
1271 1271 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1272 1272 subheader=subheader)
1273 1273
1274 1274 msg_id = msg['header']['msg_id']
1275 1275 self.outstanding.add(msg_id)
1276 1276 if ident:
1277 1277 # possibly routed to a specific engine
1278 1278 if isinstance(ident, list):
1279 1279 ident = ident[-1]
1280 1280 if ident in self._engines.values():
1281 1281 # save for later, in case of engine death
1282 1282 self._outstanding_dict[ident].add(msg_id)
1283 1283 self.history.append(msg_id)
1284 1284 self.metadata[msg_id]['submitted'] = datetime.now()
1285 1285
1286 1286 return msg
1287 1287
1288 1288 #--------------------------------------------------------------------------
1289 1289 # construct a View object
1290 1290 #--------------------------------------------------------------------------
1291 1291
    def load_balanced_view(self, targets=None):
        """construct a LoadBalancedView object.

        If no arguments are specified, create a LoadBalancedView
        using all engines.

        Parameters
        ----------

        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        # 'all' is equivalent to the default (load-balance over everything)
        if targets == 'all':
            targets = None
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1309 1309
1310 1310 def direct_view(self, targets='all'):
1311 1311 """construct a DirectView object.
1312 1312
1313 1313 If no targets are specified, create a DirectView using all engines.
1314 1314
1315 1315 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1316 1316 evaluate the target engines at each execution, whereas rc[:] will connect to
1317 1317 all *current* engines, and that list will not change.
1318 1318
1319 1319 That is, 'all' will always use all engines, whereas rc[:] will not use
1320 1320 engines added after the DirectView is constructed.
1321 1321
1322 1322 Parameters
1323 1323 ----------
1324 1324
1325 1325 targets: list,slice,int,etc. [default: use all engines]
1326 1326 The engines to use for the View
1327 1327 """
1328 1328 single = isinstance(targets, int)
1329 1329 # allow 'all' to be lazily evaluated at each execution
1330 1330 if targets != 'all':
1331 1331 targets = self._build_targets(targets)[1]
1332 1332 if single:
1333 1333 targets = targets[0]
1334 1334 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1335 1335
1336 1336 #--------------------------------------------------------------------------
1337 1337 # Query methods
1338 1338 #--------------------------------------------------------------------------
1339 1339
    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

        If the client already has the results, no request to the Hub will be made.

        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.

        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the full record.

        Examples
        --------
        ::

            In [10]: r = client.apply()

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncResult
            A single AsyncResult object will always be returned.

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        # default: the most recently submitted job
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # normalize to a list of msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # ints index into submission history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        # ids we know about locally vs. ones that must come from the Hub
        local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            # any remote id forces a Hub-backed result
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
1405 1405
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        # default: the most recently submitted job
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # normalize to a list of msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # ints index into submission history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        # wait until the hub's reply is actually readable, then recv it
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        # the hub returns a mapping of old msg_id -> new msg_id
        mapping = content['resubmitted']
        new_ids = [ mapping[msg_id] for msg_id in theids ]

        ar = AsyncHubResult(self, msg_ids=new_ids)

        if block:
            ar.wait()

        return ar
1463 1463
1464 1464 @spin_first
1465 1465 def result_status(self, msg_ids, status_only=True):
1466 1466 """Check on the status of the result(s) of the apply request with `msg_ids`.
1467 1467
1468 1468 If status_only is False, then the actual results will be retrieved, else
1469 1469 only the status of the results will be checked.
1470 1470
1471 1471 Parameters
1472 1472 ----------
1473 1473
1474 1474 msg_ids : list of msg_ids
1475 1475 if int:
1476 1476 Passed as index to self.history for convenience.
1477 1477 status_only : bool (default: True)
1478 1478 if False:
1479 1479 Retrieve the actual results of completed tasks.
1480 1480
1481 1481 Returns
1482 1482 -------
1483 1483
1484 1484 results : dict
1485 1485 There will always be the keys 'pending' and 'completed', which will
1486 1486 be lists of msg_ids that are incomplete or complete. If `status_only`
1487 1487 is False, then completed results will be keyed by their `msg_id`.
1488 1488 """
1489 1489 if not isinstance(msg_ids, (list,tuple)):
1490 1490 msg_ids = [msg_ids]
1491 1491
1492 1492 theids = []
1493 1493 for msg_id in msg_ids:
1494 1494 if isinstance(msg_id, int):
1495 1495 msg_id = self.history[msg_id]
1496 1496 if not isinstance(msg_id, basestring):
1497 1497 raise TypeError("msg_ids must be str, not %r"%msg_id)
1498 1498 theids.append(msg_id)
1499 1499
1500 1500 completed = []
1501 1501 local_results = {}
1502 1502
1503 1503 # comment this block out to temporarily disable local shortcut:
1504 1504 for msg_id in theids:
1505 1505 if msg_id in self.results:
1506 1506 completed.append(msg_id)
1507 1507 local_results[msg_id] = self.results[msg_id]
1508 1508 theids.remove(msg_id)
1509 1509
1510 1510 if theids: # some not locally cached
1511 1511 content = dict(msg_ids=theids, status_only=status_only)
1512 1512 msg = self.session.send(self._query_socket, "result_request", content=content)
1513 1513 zmq.select([self._query_socket], [], [])
1514 1514 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1515 1515 if self.debug:
1516 1516 pprint(msg)
1517 1517 content = msg['content']
1518 1518 if content['status'] != 'ok':
1519 1519 raise self._unwrap_exception(content)
1520 1520 buffers = msg['buffers']
1521 1521 else:
1522 1522 content = dict(completed=[],pending=[])
1523 1523
1524 1524 content['completed'].extend(completed)
1525 1525
1526 1526 if status_only:
1527 1527 return content
1528 1528
1529 1529 failures = []
1530 1530 # load cached results into result:
1531 1531 content.update(local_results)
1532 1532
1533 1533 # update cache with results:
1534 1534 for msg_id in sorted(theids):
1535 1535 if msg_id in content['completed']:
1536 1536 rec = content[msg_id]
1537 1537 parent = rec['header']
1538 1538 header = rec['result_header']
1539 1539 rcontent = rec['result_content']
1540 1540 iodict = rec['io']
1541 1541 if isinstance(rcontent, str):
1542 1542 rcontent = self.session.unpack(rcontent)
1543 1543
1544 1544 md = self.metadata[msg_id]
1545 1545 md.update(self._extract_metadata(header, parent, rcontent))
1546 1546 if rec.get('received'):
1547 1547 md['received'] = rec['received']
1548 1548 md.update(iodict)
1549 1549
1550 1550 if rcontent['status'] == 'ok':
1551 1551 if header['msg_type'] == 'apply_reply':
1552 1552 res,buffers = util.unserialize_object(buffers)
1553 1553 elif header['msg_type'] == 'execute_reply':
1554 1554 res = ExecuteReply(msg_id, rcontent, md)
1555 1555 else:
1556 1556 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1557 1557 else:
1558 1558 res = self._unwrap_exception(rcontent)
1559 1559 failures.append(res)
1560 1560
1561 1561 self.results[msg_id] = res
1562 1562 content[msg_id] = res
1563 1563
1564 1564 if len(theids) == 1 and failures:
1565 1565 raise failures[0]
1566 1566
1567 1567 error.collect_exceptions(failures, "result_status")
1568 1568 return content
1569 1569
1570 1570 @spin_first
1571 1571 def queue_status(self, targets='all', verbose=False):
1572 1572 """Fetch the status of engine queues.
1573 1573
1574 1574 Parameters
1575 1575 ----------
1576 1576
1577 1577 targets : int/str/list of ints/strs
1578 1578 the engines whose states are to be queried.
1579 1579 default : all
1580 1580 verbose : bool
1581 1581 Whether to return lengths only, or lists of ids for each element
1582 1582 """
1583 1583 if targets == 'all':
1584 1584 # allow 'all' to be evaluated on the engine
1585 1585 engine_ids = None
1586 1586 else:
1587 1587 engine_ids = self._build_targets(targets)[1]
1588 1588 content = dict(targets=engine_ids, verbose=verbose)
1589 1589 self.session.send(self._query_socket, "queue_request", content=content)
1590 1590 idents,msg = self.session.recv(self._query_socket, 0)
1591 1591 if self.debug:
1592 1592 pprint(msg)
1593 1593 content = msg['content']
1594 1594 status = content.pop('status')
1595 1595 if status != 'ok':
1596 1596 raise self._unwrap_exception(content)
1597 1597 content = rekey(content)
1598 1598 if isinstance(targets, int):
1599 1599 return content[targets]
1600 1600 else:
1601 1601 return content
1602 1602
1603 1603 @spin_first
1604 1604 def purge_results(self, jobs=[], targets=[]):
1605 1605 """Tell the Hub to forget results.
1606 1606
1607 1607 Individual results can be purged by msg_id, or the entire
1608 1608 history of specific targets can be purged.
1609 1609
1610 1610 Use `purge_results('all')` to scrub everything from the Hub's db.
1611 1611
1612 1612 Parameters
1613 1613 ----------
1614 1614
1615 1615 jobs : str or list of str or AsyncResult objects
1616 1616 the msg_ids whose results should be forgotten.
1617 1617 targets : int/str/list of ints/strs
1618 1618 The targets, by int_id, whose entire history is to be purged.
1619 1619
1620 1620 default : None
1621 1621 """
1622 1622 if not targets and not jobs:
1623 1623 raise ValueError("Must specify at least one of `targets` and `jobs`")
1624 1624 if targets:
1625 1625 targets = self._build_targets(targets)[1]
1626 1626
1627 1627 # construct msg_ids from jobs
1628 1628 if jobs == 'all':
1629 1629 msg_ids = jobs
1630 1630 else:
1631 1631 msg_ids = []
1632 1632 if isinstance(jobs, (basestring,AsyncResult)):
1633 1633 jobs = [jobs]
1634 1634 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1635 1635 if bad_ids:
1636 1636 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1637 1637 for j in jobs:
1638 1638 if isinstance(j, AsyncResult):
1639 1639 msg_ids.extend(j.msg_ids)
1640 1640 else:
1641 1641 msg_ids.append(j)
1642 1642
1643 1643 content = dict(engine_ids=targets, msg_ids=msg_ids)
1644 1644 self.session.send(self._query_socket, "purge_request", content=content)
1645 1645 idents, msg = self.session.recv(self._query_socket, 0)
1646 1646 if self.debug:
1647 1647 pprint(msg)
1648 1648 content = msg['content']
1649 1649 if content['status'] != 'ok':
1650 1650 raise self._unwrap_exception(content)
1651 1651
1652 1652 @spin_first
1653 1653 def hub_history(self):
1654 1654 """Get the Hub's history
1655 1655
1656 1656 Just like the Client, the Hub has a history, which is a list of msg_ids.
1657 1657 This will contain the history of all clients, and, depending on configuration,
1658 1658 may contain history across multiple cluster sessions.
1659 1659
1660 1660 Any msg_id returned here is a valid argument to `get_result`.
1661 1661
1662 1662 Returns
1663 1663 -------
1664 1664
1665 1665 msg_ids : list of strs
1666 1666 list of all msg_ids, ordered by task submission time.
1667 1667 """
1668 1668
1669 1669 self.session.send(self._query_socket, "history_request", content={})
1670 1670 idents, msg = self.session.recv(self._query_socket, 0)
1671 1671
1672 1672 if self.debug:
1673 1673 pprint(msg)
1674 1674 content = msg['content']
1675 1675 if content['status'] != 'ok':
1676 1676 raise self._unwrap_exception(content)
1677 1677 else:
1678 1678 return content['history']
1679 1679
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.

        Returns
        -------
        records : list of dicts
            Matching task records, with their binary buffers re-attached.
        """
        # allow a single key to be passed as a bare string
        if isinstance(keys, basestring):
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # all records' buffers arrive as one flat list; the per-record
        # length lists say how to slice it back apart (None = no buffers)
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records
1723 1723
1724 1724 __all__ = [ 'Client' ]
General Comments 0
You need to be logged in to leave comments. Login now