Backport PR #4074: close Client sockets if connection fails...
MinRK
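The change below wraps the Client's connection handshake so that a failure closes every socket already opened, instead of leaving them to block zmq context termination. A minimal sketch of the pattern, with hypothetical names (not the actual IPython API):

    import zmq

    def connect_or_cleanup(context, url):
        """Connect a DEALER socket; tear it down if the handshake fails."""
        sock = context.socket(zmq.DEALER)
        try:
            sock.connect(url)
            # ... request/reply handshake would go here and may raise ...
        except Exception:
            # linger=0 discards queued messages so context.term() cannot hang
            sock.close(linger=0)
            raise
        return sock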
@@ -1,1839 +1,1852 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.kernel.zmq.session import Session, Message
52 52 from IPython.kernel.zmq import serialize
53 53
54 54 from .asyncresult import AsyncResult, AsyncHubResult
55 55 from .view import DirectView, LoadBalancedView
56 56
57 57 if sys.version_info[0] >= 3:
58 58 # xrange is used in a couple 'isinstance' tests in py2
59 59 # should be just 'range' in 3k
60 60 xrange = range
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Decorators for Client methods
64 64 #--------------------------------------------------------------------------
65 65
66 66 @decorator
67 67 def spin_first(f, self, *args, **kwargs):
68 68 """Call spin() to sync state prior to calling the method."""
69 69 self.spin()
70 70 return f(self, *args, **kwargs)
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Classes
75 75 #--------------------------------------------------------------------------
76 76
77 77
78 78 class ExecuteReply(object):
79 79 """wrapper for finished Execute results"""
80 80 def __init__(self, msg_id, content, metadata):
81 81 self.msg_id = msg_id
82 82 self._content = content
83 83 self.execution_count = content['execution_count']
84 84 self.metadata = metadata
85 85
86 86 def __getitem__(self, key):
87 87 return self.metadata[key]
88 88
89 89 def __getattr__(self, key):
90 90 if key not in self.metadata:
91 91 raise AttributeError(key)
92 92 return self.metadata[key]
93 93
94 94 def __repr__(self):
95 95 pyout = self.metadata['pyout'] or {'data':{}}
96 96 text_out = pyout['data'].get('text/plain', '')
97 97 if len(text_out) > 32:
98 98 text_out = text_out[:29] + '...'
99 99
100 100 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101 101
102 102 def _repr_pretty_(self, p, cycle):
103 103 pyout = self.metadata['pyout'] or {'data':{}}
104 104 text_out = pyout['data'].get('text/plain', '')
105 105
106 106 if not text_out:
107 107 return
108 108
109 109 try:
110 110 ip = get_ipython()
111 111 except NameError:
112 112 colors = "NoColor"
113 113 else:
114 114 colors = ip.colors
115 115
116 116 if colors == "NoColor":
117 117 out = normal = ""
118 118 else:
119 119 out = TermColors.Red
120 120 normal = TermColors.Normal
121 121
122 122 if '\n' in text_out and not text_out.startswith('\n'):
123 123 # add newline for multiline reprs
124 124 text_out = '\n' + text_out
125 125
126 126 p.text(
127 127 out + u'Out[%i:%i]: ' % (
128 128 self.metadata['engine_id'], self.execution_count
129 129 ) + normal + text_out
130 130 )
131 131
132 132 def _repr_html_(self):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 return pyout['data'].get("text/html")
135 135
136 136 def _repr_latex_(self):
137 137 pyout = self.metadata['pyout'] or {'data':{}}
138 138 return pyout['data'].get("text/latex")
139 139
140 140 def _repr_json_(self):
141 141 pyout = self.metadata['pyout'] or {'data':{}}
142 142 return pyout['data'].get("application/json")
143 143
144 144 def _repr_javascript_(self):
145 145 pyout = self.metadata['pyout'] or {'data':{}}
146 146 return pyout['data'].get("application/javascript")
147 147
148 148 def _repr_png_(self):
149 149 pyout = self.metadata['pyout'] or {'data':{}}
150 150 return pyout['data'].get("image/png")
151 151
152 152 def _repr_jpeg_(self):
153 153 pyout = self.metadata['pyout'] or {'data':{}}
154 154 return pyout['data'].get("image/jpeg")
155 155
156 156 def _repr_svg_(self):
157 157 pyout = self.metadata['pyout'] or {'data':{}}
158 158 return pyout['data'].get("image/svg+xml")
159 159
160 160
161 161 class Metadata(dict):
162 162 """Subclass of dict for initializing metadata values.
163 163
164 164 Attribute access works on keys.
165 165
166 166 These objects have a strict set of keys - errors will raise if you try
167 167 to add new keys.
168 168 """
169 169 def __init__(self, *args, **kwargs):
170 170 dict.__init__(self)
171 171 md = {'msg_id' : None,
172 172 'submitted' : None,
173 173 'started' : None,
174 174 'completed' : None,
175 175 'received' : None,
176 176 'engine_uuid' : None,
177 177 'engine_id' : None,
178 178 'follow' : None,
179 179 'after' : None,
180 180 'status' : None,
181 181
182 182 'pyin' : None,
183 183 'pyout' : None,
184 184 'pyerr' : None,
185 185 'stdout' : '',
186 186 'stderr' : '',
187 187 'outputs' : [],
188 188 'data': {},
189 189 'outputs_ready' : False,
190 190 }
191 191 self.update(md)
192 192 self.update(dict(*args, **kwargs))
193 193
194 194 def __getattr__(self, key):
195 195 """getattr aliased to getitem"""
196 196 if key in self.iterkeys():
197 197 return self[key]
198 198 else:
199 199 raise AttributeError(key)
200 200
201 201 def __setattr__(self, key, value):
202 202 """setattr aliased to setitem, with strict"""
203 203 if key in self.iterkeys():
204 204 self[key] = value
205 205 else:
206 206 raise AttributeError(key)
207 207
208 208 def __setitem__(self, key, value):
209 209 """strict static key enforcement"""
210 210 if key in self.iterkeys():
211 211 dict.__setitem__(self, key, value)
212 212 else:
213 213 raise KeyError(key)
214 214
215 215
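# Usage sketch (illustrative): Metadata is a dict with a fixed key set, and
# attribute access aliases item access:
#
#     md = Metadata()
#     md.stdout            # '' (the declared default)
#     md['status'] = 'ok'  # allowed: 'status' is a known key
#     md.unknown = 1       # raises AttributeError: new keys are rejected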
216 216 class Client(HasTraits):
217 217 """A semi-synchronous client to the IPython ZMQ cluster
218 218
219 219 Parameters
220 220 ----------
221 221
222 222 url_file : str/unicode; path to ipcontroller-client.json
223 223 This JSON file should contain all the information needed to connect to a cluster,
224 224 and is likely the only argument needed.
225 225 It holds connection information for the Hub's registration; if a JSON
226 226 connector file is given, then likely no further configuration is necessary.
227 227 [Default: use profile]
228 228 profile : str
229 229 The name of the Cluster profile to be used to find connector information.
230 230 If run from an IPython application, the default profile will be the same
231 231 as the running application, otherwise it will be 'default'.
232 232 cluster_id : str
233 233 String id added to runtime files, to prevent name collisions when using
234 234 multiple clusters with a single profile simultaneously.
235 235 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 236 Since this is text inserted into filenames, typical recommendations apply:
237 237 Simple character strings are ideal, and spaces are not recommended (but
238 238 should generally work)
239 239 context : zmq.Context
240 240 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 241 debug : bool
242 242 flag for lots of message printing for debug purposes
243 243 timeout : int/float
244 244 time (in seconds) to wait for connection replies from the Hub
245 245 [Default: 10]
246 246
247 247 #-------------- session related args ----------------
248 248
249 249 config : Config object
250 250 If specified, this will be relayed to the Session for configuration
251 251 username : str
252 252 set username for the session object
253 253
254 254 #-------------- ssh related args ----------------
255 255 # These are args for configuring the ssh tunnel to be used
256 256 # credentials are used to forward connections over ssh to the Controller
257 257 # Note that the ip given in `addr` needs to be relative to sshserver
258 258 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 259 # and set sshserver as the same machine the Controller is on. However,
260 260 # the only requirement is that sshserver is able to see the Controller
261 261 # (i.e. is within the same trusted network).
262 262
263 263 sshserver : str
264 264 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 265 If keyfile or password is specified, and this is not, it will default to
266 266 the ip given in addr.
267 267 sshkey : str; path to ssh private key file
268 268 This specifies a key to be used in ssh login, default None.
269 269 Regular default ssh keys will be used without specifying this argument.
270 270 password : str
271 271 Your ssh password to sshserver. Note that if this is left None,
272 272 you will be prompted for it if passwordless key-based login is unavailable.
273 273 paramiko : bool
274 274 flag for whether to use paramiko instead of shell ssh for tunneling.
275 275 [default: True on win32, False otherwise]
276 276
277 277
278 278 Attributes
279 279 ----------
280 280
281 281 ids : list of int engine IDs
282 282 requesting the ids attribute always synchronizes
283 283 the registration state. To request ids without synchronization,
284 284 use the semi-private _ids attribute.
285 285
286 286 history : list of msg_ids
287 287 a list of msg_ids, keeping track of all the execution
288 288 messages you have submitted in order.
289 289
290 290 outstanding : set of msg_ids
291 291 a set of msg_ids that have been submitted, but whose
292 292 results have not yet been received.
293 293
294 294 results : dict
295 295 a dict of all our results, keyed by msg_id
296 296
297 297 block : bool
298 298 determines default behavior when block not specified
299 299 in execution methods
300 300
301 301 Methods
302 302 -------
303 303
304 304 spin
305 305 flushes incoming results and registration state changes
306 306 control methods spin, and requesting `ids` also ensures the state is up to date
307 307
308 308 wait
309 309 wait on one or more msg_ids
310 310
311 311 execution methods
312 312 apply
313 313 legacy: execute, run
314 314
315 315 data movement
316 316 push, pull, scatter, gather
317 317
318 318 query methods
319 319 queue_status, get_result, purge, result_status
320 320
321 321 control methods
322 322 abort, shutdown
323 323
324 324 """
325 325
326 326
327 327 block = Bool(False)
328 328 outstanding = Set()
329 329 results = Instance('collections.defaultdict', (dict,))
330 330 metadata = Instance('collections.defaultdict', (Metadata,))
331 331 history = List()
332 332 debug = Bool(False)
333 333 _spin_thread = Any()
334 334 _stop_spinning = Any()
335 335
336 336 profile=Unicode()
337 337 def _profile_default(self):
338 338 if BaseIPythonApplication.initialized():
339 339 # an IPython app *might* be running, try to get its profile
340 340 try:
341 341 return BaseIPythonApplication.instance().profile
342 342 except (AttributeError, MultipleInstanceError):
343 343 # could be a *different* subclass of config.Application,
344 344 # which would raise one of these two errors.
345 345 return u'default'
346 346 else:
347 347 return u'default'
348 348
349 349
350 350 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 351 _ids = List()
352 352 _connected=Bool(False)
353 353 _ssh=Bool(False)
354 354 _context = Instance('zmq.Context')
355 355 _config = Dict()
356 356 _engines=Instance(util.ReverseDict, (), {})
357 357 # _hub_socket=Instance('zmq.Socket')
358 358 _query_socket=Instance('zmq.Socket')
359 359 _control_socket=Instance('zmq.Socket')
360 360 _iopub_socket=Instance('zmq.Socket')
361 361 _notification_socket=Instance('zmq.Socket')
362 362 _mux_socket=Instance('zmq.Socket')
363 363 _task_socket=Instance('zmq.Socket')
364 364 _task_scheme=Unicode()
365 365 _closed = False
366 366 _ignored_control_replies=Integer(0)
367 367 _ignored_hub_replies=Integer(0)
368 368
369 369 def __new__(self, *args, **kw):
370 370 # don't raise on positional args
371 371 return HasTraits.__new__(self, **kw)
372 372
373 373 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 374 context=None, debug=False,
375 375 sshserver=None, sshkey=None, password=None, paramiko=None,
376 376 timeout=10, cluster_id=None, **extra_args
377 377 ):
378 378 if profile:
379 379 super(Client, self).__init__(debug=debug, profile=profile)
380 380 else:
381 381 super(Client, self).__init__(debug=debug)
382 382 if context is None:
383 383 context = zmq.Context.instance()
384 384 self._context = context
385 385 self._stop_spinning = Event()
386 386
387 387 if 'url_or_file' in extra_args:
388 388 url_file = extra_args['url_or_file']
389 389 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390 390
391 391 if url_file and util.is_url(url_file):
392 392 raise ValueError("single urls cannot be specified, url-files must be used.")
393 393
394 394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395 395
396 396 if self._cd is not None:
397 397 if url_file is None:
398 398 if not cluster_id:
399 399 client_json = 'ipcontroller-client.json'
400 400 else:
401 401 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 402 url_file = pjoin(self._cd.security_dir, client_json)
403 403 if url_file is None:
404 404 raise ValueError(
405 405 "I can't find enough information to connect to a hub!"
406 406 " Please specify at least one of url_file or profile."
407 407 )
408 408
409 409 with open(url_file) as f:
410 410 cfg = json.load(f)
411 411
412 412 self._task_scheme = cfg['task_scheme']
413 413
414 414 # sync defaults from args, json:
415 415 if sshserver:
416 416 cfg['ssh'] = sshserver
417 417
418 418 location = cfg.setdefault('location', None)
419 419
420 420 proto,addr = cfg['interface'].split('://')
421 421 addr = util.disambiguate_ip_address(addr, location)
422 422 cfg['interface'] = "%s://%s" % (proto, addr)
423 423
424 424 # turn interface,port into full urls:
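# e.g. interface 'tcp://10.0.1.5' with port 12345 -> 'tcp://10.0.1.5:12345'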
425 425 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 426 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427 427
428 428 url = cfg['registration']
429 429
430 430 if location is not None and addr == LOCALHOST:
431 431 # location specified, and connection is expected to be local
432 432 if location not in LOCAL_IPS and not sshserver:
433 433 # load ssh from JSON *only* if the controller is not on
434 434 # this machine
435 435 sshserver=cfg['ssh']
436 436 if location not in LOCAL_IPS and not sshserver:
437 437 # warn if no ssh specified, but SSH is probably needed
438 438 # This is only a warning, because the most likely cause
439 439 # is a local Controller on a laptop whose IP is dynamic
440 440 warnings.warn("""
441 441 Controller appears to be listening on localhost, but not on this machine.
442 442 If this is true, you should specify Client(...,sshserver='you@%s')
443 443 or instruct your controller to listen on an external IP."""%location,
444 444 RuntimeWarning)
445 445 elif not sshserver:
446 446 # otherwise sync with cfg
447 447 sshserver = cfg['ssh']
448 448
449 449 self._config = cfg
450 450
451 451 self._ssh = bool(sshserver or sshkey or password)
452 452 if self._ssh and sshserver is None:
453 453 # default to ssh via localhost
454 454 sshserver = addr
455 455 if self._ssh and password is None:
456 456 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 457 password=False
458 458 else:
459 459 password = getpass("SSH Password for %s: "%sshserver)
460 460 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461 461
462 462 # configure and construct the session
463 463 try:
464 464 extra_args['packer'] = cfg['pack']
465 465 extra_args['unpacker'] = cfg['unpack']
466 466 extra_args['key'] = cast_bytes(cfg['key'])
467 467 extra_args['signature_scheme'] = cfg['signature_scheme']
468 468 except KeyError as exc:
469 469 msg = '\n'.join([
470 470 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
471 471 "If you are reusing connection files, remove them and start ipcontroller again."
472 472 ])
473 473 raise ValueError(msg.format(exc.message))
474 474
475 475 self.session = Session(**extra_args)
476 476
477 477 self._query_socket = self._context.socket(zmq.DEALER)
478 478
479 479 if self._ssh:
480 480 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
481 481 else:
482 482 self._query_socket.connect(cfg['registration'])
483 483
484 484 self.session.debug = self.debug
485 485
486 486 self._notification_handlers = {'registration_notification' : self._register_engine,
487 487 'unregistration_notification' : self._unregister_engine,
488 488 'shutdown_notification' : lambda msg: self.close(),
489 489 }
490 490 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
491 491 'apply_reply' : self._handle_apply_reply}
492 self._connect(sshserver, ssh_kwargs, timeout)
492
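# If connecting to the Hub fails (e.g. the handshake times out), close the
# sockets created so far, so stale sockets don't block zmq context termination.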
493 try:
494 self._connect(sshserver, ssh_kwargs, timeout)
495 except:
496 self.close(linger=0)
497 raise
493 498
494 499 # last step: setup magics, if we are in IPython:
495 500
496 501 try:
497 502 ip = get_ipython()
498 503 except NameError:
499 504 return
500 505 else:
501 506 if 'px' not in ip.magics_manager.magics:
502 507 # in IPython but we are the first Client.
503 508 # activate a default view for parallel magics.
504 509 self.activate()
505 510
506 511 def __del__(self):
507 512 """cleanup sockets, but _not_ context."""
508 513 self.close()
509 514
510 515 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
511 516 if ipython_dir is None:
512 517 ipython_dir = get_ipython_dir()
513 518 if profile_dir is not None:
514 519 try:
515 520 self._cd = ProfileDir.find_profile_dir(profile_dir)
516 521 return
517 522 except ProfileDirError:
518 523 pass
519 524 elif profile is not None:
520 525 try:
521 526 self._cd = ProfileDir.find_profile_dir_by_name(
522 527 ipython_dir, profile)
523 528 return
524 529 except ProfileDirError:
525 530 pass
526 531 self._cd = None
527 532
528 533 def _update_engines(self, engines):
529 534 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
530 535 for k,v in engines.iteritems():
531 536 eid = int(k)
532 537 if eid not in self._engines:
533 538 self._ids.append(eid)
534 539 self._engines[eid] = v
535 540 self._ids = sorted(self._ids)
536 541 if sorted(self._engines.keys()) != range(len(self._engines)) and \
537 542 self._task_scheme == 'pure' and self._task_socket:
538 543 self._stop_scheduling_tasks()
539 544
540 545 def _stop_scheduling_tasks(self):
541 546 """Stop scheduling tasks because an engine has been unregistered
542 547 from a pure ZMQ scheduler.
543 548 """
544 549 self._task_socket.close()
545 550 self._task_socket = None
546 551 msg = "An engine has been unregistered, and we are using pure " +\
547 552 "ZMQ task scheduling. Task farming will be disabled."
548 553 if self.outstanding:
549 554 msg += " If you were running tasks when this happened, " +\
550 555 "some `outstanding` msg_ids may never resolve."
551 556 warnings.warn(msg, RuntimeWarning)
552 557
553 558 def _build_targets(self, targets):
554 559 """Turn valid target IDs or 'all' into two lists:
555 560 (int_ids, uuids).
556 561 """
557 562 if not self._ids:
558 563 # flush notification socket if no engines yet, just in case
559 564 if not self.ids:
560 565 raise error.NoEnginesRegistered("Can't build targets without any engines")
561 566
562 567 if targets is None:
563 568 targets = self._ids
564 569 elif isinstance(targets, basestring):
565 570 if targets.lower() == 'all':
566 571 targets = self._ids
567 572 else:
568 573 raise TypeError("%r not valid str target, must be 'all'"%(targets))
569 574 elif isinstance(targets, int):
570 575 if targets < 0:
571 576 targets = self.ids[targets]
572 577 if targets not in self._ids:
573 578 raise IndexError("No such engine: %i"%targets)
574 579 targets = [targets]
575 580
576 581 if isinstance(targets, slice):
577 582 indices = range(len(self._ids))[targets]
578 583 ids = self.ids
579 584 targets = [ ids[i] for i in indices ]
580 585
581 586 if not isinstance(targets, (tuple, list, xrange)):
582 587 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
583 588
584 589 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
585 590
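# e.g. (sketch): with engines {0: 'uuid-a', 1: 'uuid-b'} registered,
#     rc._build_targets('all')  ->  ([b'uuid-a', b'uuid-b'], [0, 1])
#     rc._build_targets(-1)     ->  ([b'uuid-b'], [1])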
586 591 def _connect(self, sshserver, ssh_kwargs, timeout):
587 592 """setup all our socket connections to the cluster. This is called from
588 593 __init__."""
589 594
590 595 # Maybe allow reconnecting?
591 596 if self._connected:
592 597 return
593 598 self._connected=True
594 599
595 600 def connect_socket(s, url):
596 # url = util.disambiguate_url(url, self._config['location'])
597 601 if self._ssh:
598 602 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
599 603 else:
600 604 return s.connect(url)
601 605
602 606 self.session.send(self._query_socket, 'connection_request')
603 607 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
604 608 poller = zmq.Poller()
605 609 poller.register(self._query_socket, zmq.POLLIN)
606 610 # poll expects milliseconds, timeout is seconds
607 611 evts = poller.poll(timeout*1000)
608 612 if not evts:
609 613 raise error.TimeoutError("Hub connection request timed out")
610 614 idents,msg = self.session.recv(self._query_socket,mode=0)
611 615 if self.debug:
612 616 pprint(msg)
613 617 content = msg['content']
614 618 # self._config['registration'] = dict(content)
615 619 cfg = self._config
616 620 if content['status'] == 'ok':
617 621 self._mux_socket = self._context.socket(zmq.DEALER)
618 622 connect_socket(self._mux_socket, cfg['mux'])
619 623
620 624 self._task_socket = self._context.socket(zmq.DEALER)
621 625 connect_socket(self._task_socket, cfg['task'])
622 626
623 627 self._notification_socket = self._context.socket(zmq.SUB)
624 628 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
625 629 connect_socket(self._notification_socket, cfg['notification'])
626 630
627 631 self._control_socket = self._context.socket(zmq.DEALER)
628 632 connect_socket(self._control_socket, cfg['control'])
629 633
630 634 self._iopub_socket = self._context.socket(zmq.SUB)
631 635 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
632 636 connect_socket(self._iopub_socket, cfg['iopub'])
633 637
634 638 self._update_engines(dict(content['engines']))
635 639 else:
636 640 self._connected = False
637 641 raise Exception("Failed to connect!")
638 642
639 643 #--------------------------------------------------------------------------
640 644 # handlers and callbacks for incoming messages
641 645 #--------------------------------------------------------------------------
642 646
643 647 def _unwrap_exception(self, content):
644 648 """unwrap exception, and remap engine_id to int."""
645 649 e = error.unwrap_exception(content)
646 650 # print e.traceback
647 651 if e.engine_info:
648 652 e_uuid = e.engine_info['engine_uuid']
649 653 eid = self._engines[e_uuid]
650 654 e.engine_info['engine_id'] = eid
651 655 return e
652 656
653 657 def _extract_metadata(self, msg):
654 658 header = msg['header']
655 659 parent = msg['parent_header']
656 660 msg_meta = msg['metadata']
657 661 content = msg['content']
658 662 md = {'msg_id' : parent['msg_id'],
659 663 'received' : datetime.now(),
660 664 'engine_uuid' : msg_meta.get('engine', None),
661 665 'follow' : msg_meta.get('follow', []),
662 666 'after' : msg_meta.get('after', []),
663 667 'status' : content['status'],
664 668 }
665 669
666 670 if md['engine_uuid'] is not None:
667 671 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
668 672
669 673 if 'date' in parent:
670 674 md['submitted'] = parent['date']
671 675 if 'started' in msg_meta:
672 676 md['started'] = msg_meta['started']
673 677 if 'date' in header:
674 678 md['completed'] = header['date']
675 679 return md
676 680
677 681 def _register_engine(self, msg):
678 682 """Register a new engine, and update our connection info."""
679 683 content = msg['content']
680 684 eid = content['id']
681 685 d = {eid : content['uuid']}
682 686 self._update_engines(d)
683 687
684 688 def _unregister_engine(self, msg):
685 689 """Unregister an engine that has died."""
686 690 content = msg['content']
687 691 eid = int(content['id'])
688 692 if eid in self._ids:
689 693 self._ids.remove(eid)
690 694 uuid = self._engines.pop(eid)
691 695
692 696 self._handle_stranded_msgs(eid, uuid)
693 697
694 698 if self._task_socket and self._task_scheme == 'pure':
695 699 self._stop_scheduling_tasks()
696 700
697 701 def _handle_stranded_msgs(self, eid, uuid):
698 702 """Handle messages known to be on an engine when the engine unregisters.
699 703
700 704 It is possible that this will fire prematurely - that is, an engine will
701 705 go down after completing a result, and the client will be notified
702 706 of the unregistration and later receive the successful result.
703 707 """
704 708
705 709 outstanding = self._outstanding_dict[uuid]
706 710
707 711 for msg_id in list(outstanding):
708 712 if msg_id in self.results:
709 713 # we already have the result
710 714 continue
711 715 try:
712 716 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
713 717 except:
714 718 content = error.wrap_exception()
715 719 # build a fake message:
716 720 msg = self.session.msg('apply_reply', content=content)
717 721 msg['parent_header']['msg_id'] = msg_id
718 722 msg['metadata']['engine'] = uuid
719 723 self._handle_apply_reply(msg)
720 724
721 725 def _handle_execute_reply(self, msg):
722 726 """Save the reply to an execute_request into our results.
723 727
724 728 execute requests are sent by view.execute and the parallel magics; apply is used for function calls.
725 729 """
726 730
727 731 parent = msg['parent_header']
728 732 msg_id = parent['msg_id']
729 733 if msg_id not in self.outstanding:
730 734 if msg_id in self.history:
731 735 print ("got stale result: %s"%msg_id)
732 736 else:
733 737 print ("got unknown result: %s"%msg_id)
734 738 else:
735 739 self.outstanding.remove(msg_id)
736 740
737 741 content = msg['content']
738 742 header = msg['header']
739 743
740 744 # construct metadata:
741 745 md = self.metadata[msg_id]
742 746 md.update(self._extract_metadata(msg))
743 747 # is this redundant?
744 748 self.metadata[msg_id] = md
745 749
746 750 e_outstanding = self._outstanding_dict[md['engine_uuid']]
747 751 if msg_id in e_outstanding:
748 752 e_outstanding.remove(msg_id)
749 753
750 754 # construct result:
751 755 if content['status'] == 'ok':
752 756 self.results[msg_id] = ExecuteReply(msg_id, content, md)
753 757 elif content['status'] == 'aborted':
754 758 self.results[msg_id] = error.TaskAborted(msg_id)
755 759 elif content['status'] == 'resubmitted':
756 760 # TODO: handle resubmission
757 761 pass
758 762 else:
759 763 self.results[msg_id] = self._unwrap_exception(content)
760 764
761 765 def _handle_apply_reply(self, msg):
762 766 """Save the reply to an apply_request into our results."""
763 767 parent = msg['parent_header']
764 768 msg_id = parent['msg_id']
765 769 if msg_id not in self.outstanding:
766 770 if msg_id in self.history:
767 771 print ("got stale result: %s"%msg_id)
768 772 print (self.results[msg_id])
769 773 print (msg)
770 774 else:
771 775 print ("got unknown result: %s"%msg_id)
772 776 else:
773 777 self.outstanding.remove(msg_id)
774 778 content = msg['content']
775 779 header = msg['header']
776 780
777 781 # construct metadata:
778 782 md = self.metadata[msg_id]
779 783 md.update(self._extract_metadata(msg))
780 784 # is this redundant?
781 785 self.metadata[msg_id] = md
782 786
783 787 e_outstanding = self._outstanding_dict[md['engine_uuid']]
784 788 if msg_id in e_outstanding:
785 789 e_outstanding.remove(msg_id)
786 790
787 791 # construct result:
788 792 if content['status'] == 'ok':
789 793 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
790 794 elif content['status'] == 'aborted':
791 795 self.results[msg_id] = error.TaskAborted(msg_id)
792 796 elif content['status'] == 'resubmitted':
793 797 # TODO: handle resubmission
794 798 pass
795 799 else:
796 800 self.results[msg_id] = self._unwrap_exception(content)
797 801
798 802 def _flush_notifications(self):
799 803 """Flush notifications of engine registrations waiting
800 804 in ZMQ queue."""
801 805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
802 806 while msg is not None:
803 807 if self.debug:
804 808 pprint(msg)
805 809 msg_type = msg['header']['msg_type']
806 810 handler = self._notification_handlers.get(msg_type, None)
807 811 if handler is None:
808 812 raise Exception("Unhandled message type: %s" % msg_type)
809 813 else:
810 814 handler(msg)
811 815 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
812 816
813 817 def _flush_results(self, sock):
814 818 """Flush task or queue results waiting in ZMQ queue."""
815 819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
816 820 while msg is not None:
817 821 if self.debug:
818 822 pprint(msg)
819 823 msg_type = msg['header']['msg_type']
820 824 handler = self._queue_handlers.get(msg_type, None)
821 825 if handler is None:
822 826 raise Exception("Unhandled message type: %s" % msg_type)
823 827 else:
824 828 handler(msg)
825 829 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
826 830
827 831 def _flush_control(self, sock):
828 832 """Flush replies from the control channel waiting
829 833 in the ZMQ queue.
830 834
831 835 Currently: ignore them."""
832 836 if self._ignored_control_replies <= 0:
833 837 return
834 838 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
835 839 while msg is not None:
836 840 self._ignored_control_replies -= 1
837 841 if self.debug:
838 842 pprint(msg)
839 843 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 844
841 845 def _flush_ignored_control(self):
842 846 """flush ignored control replies"""
843 847 while self._ignored_control_replies > 0:
844 848 self.session.recv(self._control_socket)
845 849 self._ignored_control_replies -= 1
846 850
847 851 def _flush_ignored_hub_replies(self):
848 852 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
849 853 while msg is not None:
850 854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
851 855
852 856 def _flush_iopub(self, sock):
853 857 """Flush replies from the iopub channel waiting
854 858 in the ZMQ queue.
855 859 """
856 860 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
857 861 while msg is not None:
858 862 if self.debug:
859 863 pprint(msg)
860 864 parent = msg['parent_header']
861 865 # ignore IOPub messages with no parent.
862 866 # Caused by print statements or warnings from before the first execution.
863 867 if not parent:
864 868 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
865 869 continue
866 870 msg_id = parent['msg_id']
867 871 content = msg['content']
868 872 header = msg['header']
869 873 msg_type = msg['header']['msg_type']
870 874
871 875 # init metadata:
872 876 md = self.metadata[msg_id]
873 877
874 878 if msg_type == 'stream':
875 879 name = content['name']
876 880 s = md[name] or ''
877 881 md[name] = s + content['data']
878 882 elif msg_type == 'pyerr':
879 883 md.update({'pyerr' : self._unwrap_exception(content)})
880 884 elif msg_type == 'pyin':
881 885 md.update({'pyin' : content['code']})
882 886 elif msg_type == 'display_data':
883 887 md['outputs'].append(content)
884 888 elif msg_type == 'pyout':
885 889 md['pyout'] = content
886 890 elif msg_type == 'data_message':
887 891 data, remainder = serialize.unserialize_object(msg['buffers'])
888 892 md['data'].update(data)
889 893 elif msg_type == 'status':
890 894 # idle message comes after all outputs
891 895 if content['execution_state'] == 'idle':
892 896 md['outputs_ready'] = True
893 897 else:
894 898 # unhandled msg_type
895 899 pass
896 900
897 901 # redundant?
898 902 self.metadata[msg_id] = md
899 903
900 904 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
901 905
902 906 #--------------------------------------------------------------------------
903 907 # len, getitem
904 908 #--------------------------------------------------------------------------
905 909
906 910 def __len__(self):
907 911 """len(client) returns # of engines."""
908 912 return len(self.ids)
909 913
910 914 def __getitem__(self, key):
911 915 """index access returns DirectView multiplexer objects
912 916
913 917 Must be int, slice, or list/tuple/xrange of ints"""
914 918 if not isinstance(key, (int, slice, tuple, list, xrange)):
915 919 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
916 920 else:
917 921 return self.direct_view(key)
918 922
919 923 #--------------------------------------------------------------------------
920 924 # Begin public methods
921 925 #--------------------------------------------------------------------------
922 926
923 927 @property
924 928 def ids(self):
925 929 """Always up-to-date ids property."""
926 930 self._flush_notifications()
927 931 # always copy:
928 932 return list(self._ids)
929 933
930 934 def activate(self, targets='all', suffix=''):
931 935 """Create a DirectView and register it with IPython magics
932 936
933 937 Defines the magics `%px, %autopx, %pxresult, %%px`
934 938
935 939 Parameters
936 940 ----------
937 941
938 942 targets: int, list of ints, or 'all'
939 943 The engines on which the view's magics will run
940 944 suffix: str [default: '']
941 945 The suffix, if any, for the magics. This allows you to have
942 946 multiple views associated with parallel magics at the same time.
943 947
944 948 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
945 949 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
946 950 on engine 0.
947 951 """
948 952 view = self.direct_view(targets)
949 953 view.block = True
950 954 view.activate(suffix)
951 955 return view
952 956
953 def close(self):
957 def close(self, linger=None):
958 """Close my zmq Sockets
959
960 If `linger` is given, set the zmq LINGER socket option before closing,
961 which allows discarding of undelivered messages.
962 """
954 963 if self._closed:
955 964 return
956 965 self.stop_spin_thread()
957 snames = filter(lambda n: n.endswith('socket'), dir(self))
958 for socket in map(lambda name: getattr(self, name), snames):
959 if isinstance(socket, zmq.Socket) and not socket.closed:
960 socket.close()
966 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
967 for name in snames:
968 socket = getattr(self, name)
969 if socket is not None and not socket.closed:
970 if linger is not None:
971 socket.close(linger=linger)
972 else:
973 socket.close()
961 974 self._closed = True
962 975
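# Illustrative sketch of the linger semantics used above (hypothetical address):
#
#     ctx = zmq.Context()
#     s = ctx.socket(zmq.DEALER)
#     s.connect("tcp://127.0.0.1:12345")   # nothing listening here
#     s.send(b"hello")                     # queued locally, never delivered
#     s.close(linger=0)                    # discard the queued message
#     ctx.term()                           # returns instead of blocking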
963 976 def _spin_every(self, interval=1):
964 977 """target func for use in spin_thread"""
965 978 while True:
966 979 if self._stop_spinning.is_set():
967 980 return
968 981 time.sleep(interval)
969 982 self.spin()
970 983
971 984 def spin_thread(self, interval=1):
972 985 """call Client.spin() in a background thread on some regular interval
973 986
974 987 This helps ensure that messages don't pile up too much in the zmq queue
975 988 while you are working on other things, or just leaving an idle terminal.
976 989
977 990 It also helps limit potential padding of the `received` timestamp
978 991 on AsyncResult objects, used for timings.
979 992
980 993 Parameters
981 994 ----------
982 995
983 996 interval : float, optional
984 997 The interval on which to spin the client in the background thread
985 998 (simply passed to time.sleep).
986 999
987 1000 Notes
988 1001 -----
989 1002
990 1003 For precision timing, you may want to use this method to put a bound
991 1004 on the jitter (in seconds) in `received` timestamps used
992 1005 in AsyncResult.wall_time.
993 1006
994 1007 """
995 1008 if self._spin_thread is not None:
996 1009 self.stop_spin_thread()
997 1010 self._stop_spinning.clear()
998 1011 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
999 1012 self._spin_thread.daemon = True
1000 1013 self._spin_thread.start()
1001 1014
1002 1015 def stop_spin_thread(self):
1003 1016 """stop background spin_thread, if any"""
1004 1017 if self._spin_thread is not None:
1005 1018 self._stop_spinning.set()
1006 1019 self._spin_thread.join()
1007 1020 self._spin_thread = None
1008 1021
1009 1022 def spin(self):
1010 1023 """Flush any registration notifications and execution results
1011 1024 waiting in the ZMQ queue.
1012 1025 """
1013 1026 if self._notification_socket:
1014 1027 self._flush_notifications()
1015 1028 if self._iopub_socket:
1016 1029 self._flush_iopub(self._iopub_socket)
1017 1030 if self._mux_socket:
1018 1031 self._flush_results(self._mux_socket)
1019 1032 if self._task_socket:
1020 1033 self._flush_results(self._task_socket)
1021 1034 if self._control_socket:
1022 1035 self._flush_control(self._control_socket)
1023 1036 if self._query_socket:
1024 1037 self._flush_ignored_hub_replies()
1025 1038
1026 1039 def wait(self, jobs=None, timeout=-1):
1027 1040 """waits on one or more `jobs`, for up to `timeout` seconds.
1028 1041
1029 1042 Parameters
1030 1043 ----------
1031 1044
1032 1045 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1033 1046 ints are indices to self.history
1034 1047 strs are msg_ids
1035 1048 default: wait on all outstanding messages
1036 1049 timeout : float
1037 1050 a time in seconds, after which to give up.
1038 1051 default is -1, which means no timeout
1039 1052
1040 1053 Returns
1041 1054 -------
1042 1055
1043 1056 True : when all msg_ids are done
1044 1057 False : timeout reached, some msg_ids still outstanding
1045 1058 """
1046 1059 tic = time.time()
1047 1060 if jobs is None:
1048 1061 theids = self.outstanding
1049 1062 else:
1050 1063 if isinstance(jobs, (int, basestring, AsyncResult)):
1051 1064 jobs = [jobs]
1052 1065 theids = set()
1053 1066 for job in jobs:
1054 1067 if isinstance(job, int):
1055 1068 # index access
1056 1069 job = self.history[job]
1057 1070 elif isinstance(job, AsyncResult):
1058 1071 map(theids.add, job.msg_ids)
1059 1072 continue
1060 1073 theids.add(job)
1061 1074 if not theids.intersection(self.outstanding):
1062 1075 return True
1063 1076 self.spin()
1064 1077 while theids.intersection(self.outstanding):
1065 1078 if timeout >= 0 and ( time.time()-tic ) > timeout:
1066 1079 break
1067 1080 time.sleep(1e-3)
1068 1081 self.spin()
1069 1082 return len(theids.intersection(self.outstanding)) == 0
1070 1083
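# Usage sketch (assuming a DirectView `view` and its apply_async method):
#
#     ar = view.apply_async(f, x)
#     rc.wait(ar, timeout=5)   # True if done within 5 seconds, else False
#     rc.wait()                # block until all outstanding messages finish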
1071 1084 #--------------------------------------------------------------------------
1072 1085 # Control methods
1073 1086 #--------------------------------------------------------------------------
1074 1087
1075 1088 @spin_first
1076 1089 def clear(self, targets=None, block=None):
1077 1090 """Clear the namespace in target(s)."""
1078 1091 block = self.block if block is None else block
1079 1092 targets = self._build_targets(targets)[0]
1080 1093 for t in targets:
1081 1094 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1082 1095 error = False
1083 1096 if block:
1084 1097 self._flush_ignored_control()
1085 1098 for i in range(len(targets)):
1086 1099 idents,msg = self.session.recv(self._control_socket,0)
1087 1100 if self.debug:
1088 1101 pprint(msg)
1089 1102 if msg['content']['status'] != 'ok':
1090 1103 error = self._unwrap_exception(msg['content'])
1091 1104 else:
1092 1105 self._ignored_control_replies += len(targets)
1093 1106 if error:
1094 1107 raise error
1095 1108
1096 1109
1097 1110 @spin_first
1098 1111 def abort(self, jobs=None, targets=None, block=None):
1099 1112 """Abort specific jobs from the execution queues of target(s).
1100 1113
1101 1114 This is a mechanism to prevent jobs that have already been submitted
1102 1115 from executing.
1103 1116
1104 1117 Parameters
1105 1118 ----------
1106 1119
1107 1120 jobs : msg_id, list of msg_ids, or AsyncResult
1108 1121 The jobs to be aborted
1109 1122
1110 1123 If unspecified/None: abort all outstanding jobs.
1111 1124
1112 1125 """
1113 1126 block = self.block if block is None else block
1114 1127 jobs = jobs if jobs is not None else list(self.outstanding)
1115 1128 targets = self._build_targets(targets)[0]
1116 1129
1117 1130 msg_ids = []
1118 1131 if isinstance(jobs, (basestring,AsyncResult)):
1119 1132 jobs = [jobs]
1120 1133 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1121 1134 if bad_ids:
1122 1135 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1123 1136 for j in jobs:
1124 1137 if isinstance(j, AsyncResult):
1125 1138 msg_ids.extend(j.msg_ids)
1126 1139 else:
1127 1140 msg_ids.append(j)
1128 1141 content = dict(msg_ids=msg_ids)
1129 1142 for t in targets:
1130 1143 self.session.send(self._control_socket, 'abort_request',
1131 1144 content=content, ident=t)
1132 1145 error = False
1133 1146 if block:
1134 1147 self._flush_ignored_control()
1135 1148 for i in range(len(targets)):
1136 1149 idents,msg = self.session.recv(self._control_socket,0)
1137 1150 if self.debug:
1138 1151 pprint(msg)
1139 1152 if msg['content']['status'] != 'ok':
1140 1153 error = self._unwrap_exception(msg['content'])
1141 1154 else:
1142 1155 self._ignored_control_replies += len(targets)
1143 1156 if error:
1144 1157 raise error
1145 1158
1146 1159 @spin_first
1147 1160 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1148 1161 """Terminates one or more engine processes, optionally including the hub.
1149 1162
1150 1163 Parameters
1151 1164 ----------
1152 1165
1153 1166 targets: list of ints or 'all' [default: all]
1154 1167 Which engines to shutdown.
1155 1168 hub: bool [default: False]
1156 1169 Whether to include the Hub. hub=True implies targets='all'.
1157 1170 block: bool [default: self.block]
1158 1171 Whether to wait for clean shutdown replies or not.
1159 1172 restart: bool [default: False]
1160 1173 NOT IMPLEMENTED
1161 1174 whether to restart engines after shutting them down.
1162 1175 """
1163 1176 from IPython.parallel.error import NoEnginesRegistered
1164 1177 if restart:
1165 1178 raise NotImplementedError("Engine restart is not yet implemented")
1166 1179
1167 1180 block = self.block if block is None else block
1168 1181 if hub:
1169 1182 targets = 'all'
1170 1183 try:
1171 1184 targets = self._build_targets(targets)[0]
1172 1185 except NoEnginesRegistered:
1173 1186 targets = []
1174 1187 for t in targets:
1175 1188 self.session.send(self._control_socket, 'shutdown_request',
1176 1189 content={'restart':restart},ident=t)
1177 1190 error = False
1178 1191 if block or hub:
1179 1192 self._flush_ignored_control()
1180 1193 for i in range(len(targets)):
1181 1194 idents,msg = self.session.recv(self._control_socket, 0)
1182 1195 if self.debug:
1183 1196 pprint(msg)
1184 1197 if msg['content']['status'] != 'ok':
1185 1198 error = self._unwrap_exception(msg['content'])
1186 1199 else:
1187 1200 self._ignored_control_replies += len(targets)
1188 1201
1189 1202 if hub:
1190 1203 time.sleep(0.25)
1191 1204 self.session.send(self._query_socket, 'shutdown_request')
1192 1205 idents,msg = self.session.recv(self._query_socket, 0)
1193 1206 if self.debug:
1194 1207 pprint(msg)
1195 1208 if msg['content']['status'] != 'ok':
1196 1209 error = self._unwrap_exception(msg['content'])
1197 1210
1198 1211 if error:
1199 1212 raise error
1200 1213
1201 1214 #--------------------------------------------------------------------------
1202 1215 # Execution related methods
1203 1216 #--------------------------------------------------------------------------
1204 1217
1205 1218 def _maybe_raise(self, result):
1206 1219 """wrapper for maybe raising an exception if apply failed."""
1207 1220 if isinstance(result, error.RemoteError):
1208 1221 raise result
1209 1222
1210 1223 return result
1211 1224
1212 1225 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1213 1226 ident=None):
1214 1227 """construct and send an apply message via a socket.
1215 1228
1216 1229 This is the principal method with which all engine execution is performed by views.
1217 1230 """
1218 1231
1219 1232 if self._closed:
1220 1233 raise RuntimeError("Client cannot be used after its sockets have been closed")
1221 1234
1222 1235 # defaults:
1223 1236 args = args if args is not None else []
1224 1237 kwargs = kwargs if kwargs is not None else {}
1225 1238 metadata = metadata if metadata is not None else {}
1226 1239
1227 1240 # validate arguments
1228 1241 if not callable(f) and not isinstance(f, Reference):
1229 1242 raise TypeError("f must be callable, not %s"%type(f))
1230 1243 if not isinstance(args, (tuple, list)):
1231 1244 raise TypeError("args must be tuple or list, not %s"%type(args))
1232 1245 if not isinstance(kwargs, dict):
1233 1246 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1234 1247 if not isinstance(metadata, dict):
1235 1248 raise TypeError("metadata must be dict, not %s"%type(metadata))
1236 1249
1237 1250 bufs = serialize.pack_apply_message(f, args, kwargs,
1238 1251 buffer_threshold=self.session.buffer_threshold,
1239 1252 item_threshold=self.session.item_threshold,
1240 1253 )
1241 1254
1242 1255 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1243 1256 metadata=metadata, track=track)
1244 1257
1245 1258 msg_id = msg['header']['msg_id']
1246 1259 self.outstanding.add(msg_id)
1247 1260 if ident:
1248 1261 # possibly routed to a specific engine
1249 1262 if isinstance(ident, list):
1250 1263 ident = ident[-1]
1251 1264 if ident in self._engines.values():
1252 1265 # save for later, in case of engine death
1253 1266 self._outstanding_dict[ident].add(msg_id)
1254 1267 self.history.append(msg_id)
1255 1268 self.metadata[msg_id]['submitted'] = datetime.now()
1256 1269
1257 1270 return msg
1258 1271
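# Views funnel execution through this method; a direct call looks roughly
# like this sketch (assuming engine 0 is registered):
#
#     ident = rc._build_targets(0)[0]   # identity bytes for engine 0
#     msg = rc.send_apply_request(rc._mux_socket, f, (1, 2), {}, ident=ident)
#     msg_id = msg['header']['msg_id']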
1259 1272 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1260 1273 """construct and send an execute request via a socket.
1261 1274
1262 1275 """
1263 1276
1264 1277 if self._closed:
1265 1278 raise RuntimeError("Client cannot be used after its sockets have been closed")
1266 1279
1267 1280 # defaults:
1268 1281 metadata = metadata if metadata is not None else {}
1269 1282
1270 1283 # validate arguments
1271 1284 if not isinstance(code, basestring):
1272 1285 raise TypeError("code must be text, not %s" % type(code))
1273 1286 if not isinstance(metadata, dict):
1274 1287 raise TypeError("metadata must be dict, not %s" % type(metadata))
1275 1288
1276 1289 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1277 1290
1278 1291
1279 1292 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1280 1293 metadata=metadata)
1281 1294
1282 1295 msg_id = msg['header']['msg_id']
1283 1296 self.outstanding.add(msg_id)
1284 1297 if ident:
1285 1298 # possibly routed to a specific engine
1286 1299 if isinstance(ident, list):
1287 1300 ident = ident[-1]
1288 1301 if ident in self._engines.values():
1289 1302 # save for later, in case of engine death
1290 1303 self._outstanding_dict[ident].add(msg_id)
1291 1304 self.history.append(msg_id)
1292 1305 self.metadata[msg_id]['submitted'] = datetime.now()
1293 1306
1294 1307 return msg
1295 1308
1296 1309 #--------------------------------------------------------------------------
1297 1310 # construct a View object
1298 1311 #--------------------------------------------------------------------------
1299 1312
1300 1313 def load_balanced_view(self, targets=None):
1301 1314 """construct a DirectView object.
1302 1315
1303 1316 If no arguments are specified, create a LoadBalancedView
1304 1317 using all engines.
1305 1318
1306 1319 Parameters
1307 1320 ----------
1308 1321
1309 1322 targets: list,slice,int,etc. [default: use all engines]
1310 1323 The subset of engines across which to load-balance
1311 1324 """
1312 1325 if targets == 'all':
1313 1326 targets = None
1314 1327 if targets is not None:
1315 1328 targets = self._build_targets(targets)[1]
1316 1329 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1317 1330
1318 1331 def direct_view(self, targets='all'):
1319 1332 """construct a DirectView object.
1320 1333
1321 1334 If no targets are specified, create a DirectView using all engines.
1322 1335
1323 1336 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1324 1337 evaluate the target engines at each execution, whereas rc[:] will connect to
1325 1338 all *current* engines, and that list will not change.
1326 1339
1327 1340 That is, 'all' will always use all engines, whereas rc[:] will not use
1328 1341 engines added after the DirectView is constructed.
1329 1342
1330 1343 Parameters
1331 1344 ----------
1332 1345
1333 1346 targets: list,slice,int,etc. [default: use all engines]
1334 1347 The engines to use for the View
1335 1348 """
1336 1349 single = isinstance(targets, int)
1337 1350 # allow 'all' to be lazily evaluated at each execution
1338 1351 if targets != 'all':
1339 1352 targets = self._build_targets(targets)[1]
1340 1353 if single:
1341 1354 targets = targets[0]
1342 1355 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1343 1356
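# Contrast (sketch): with engines 0-3 registered,
#
#     dv_all = rc.direct_view('all')   # targets re-evaluated at each execution
#     dv_now = rc[:]                   # pinned to the engines present right now
#
# after a new engine registers, dv_all includes it; dv_now does not.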
1344 1357 #--------------------------------------------------------------------------
1345 1358 # Query methods
1346 1359 #--------------------------------------------------------------------------
1347 1360
1348 1361 @spin_first
1349 1362 def get_result(self, indices_or_msg_ids=None, block=None):
1350 1363 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1351 1364
1352 1365 If the client already has the results, no request to the Hub will be made.
1353 1366
1354 1367 This is a convenient way to construct AsyncResult objects, which are wrappers
1355 1368 that include metadata about execution, and allow for awaiting results that
1356 1369 were not submitted by this Client.
1357 1370
1358 1371 It can also be a convenient way to retrieve the metadata associated with
1359 1372 blocking execution, since it always retrieves
1360 1373
1361 1374 Examples
1362 1375 --------
1363 1376 ::
1364 1377
1365 1378 In [10]: ar = client.get_result(msg_id)
1366 1379
1367 1380 Parameters
1368 1381 ----------
1369 1382
1370 1383 indices_or_msg_ids : integer history index, str msg_id, or list of either
1371 1384 The history indices or msg_ids of the results to be retrieved
1372 1385
1373 1386 block : bool
1374 1387 Whether to wait for the result to be done
1375 1388
1376 1389 Returns
1377 1390 -------
1378 1391
1379 1392 AsyncResult
1380 1393 A single AsyncResult object will always be returned.
1381 1394
1382 1395 AsyncHubResult
1383 1396 A subclass of AsyncResult that retrieves results from the Hub
1384 1397
1385 1398 """
1386 1399 block = self.block if block is None else block
1387 1400 if indices_or_msg_ids is None:
1388 1401 indices_or_msg_ids = -1
1389 1402
1390 1403 single_result = False
1391 1404 if not isinstance(indices_or_msg_ids, (list,tuple)):
1392 1405 indices_or_msg_ids = [indices_or_msg_ids]
1393 1406 single_result = True
1394 1407
1395 1408 theids = []
1396 1409 for id in indices_or_msg_ids:
1397 1410 if isinstance(id, int):
1398 1411 id = self.history[id]
1399 1412 if not isinstance(id, basestring):
1400 1413 raise TypeError("indices must be str or int, not %r"%id)
1401 1414 theids.append(id)
1402 1415
1403 1416 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1404 1417 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1405 1418
1406 1419 # given a single msg_id initially, get_result should get the result itself,
1407 1420 # not a length-one list
1408 1421 if single_result:
1409 1422 theids = theids[0]
1410 1423
1411 1424 if remote_ids:
1412 1425 ar = AsyncHubResult(self, msg_ids=theids)
1413 1426 else:
1414 1427 ar = AsyncResult(self, msg_ids=theids)
1415 1428
1416 1429 if block:
1417 1430 ar.wait()
1418 1431
1419 1432 return ar
1420 1433
1421 1434 @spin_first
1422 1435 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1423 1436 """Resubmit one or more tasks.
1424 1437
1425 1438 in-flight tasks may not be resubmitted.
1426 1439
1427 1440 Parameters
1428 1441 ----------
1429 1442
1430 1443 indices_or_msg_ids : integer history index, str msg_id, or list of either
1431 1444 The history indices or msg_ids of the tasks to be resubmitted
1432 1445
1433 1446 block : bool
1434 1447 Whether to wait for the result to be done
1435 1448
1436 1449 Returns
1437 1450 -------
1438 1451
1439 1452 AsyncHubResult
1440 1453 A subclass of AsyncResult that retrieves results from the Hub
1441 1454
1442 1455 """
1443 1456 block = self.block if block is None else block
1444 1457 if indices_or_msg_ids is None:
1445 1458 indices_or_msg_ids = -1
1446 1459
1447 1460 if not isinstance(indices_or_msg_ids, (list,tuple)):
1448 1461 indices_or_msg_ids = [indices_or_msg_ids]
1449 1462
1450 1463 theids = []
1451 1464 for id in indices_or_msg_ids:
1452 1465 if isinstance(id, int):
1453 1466 id = self.history[id]
1454 1467 if not isinstance(id, basestring):
1455 1468 raise TypeError("indices must be str or int, not %r"%id)
1456 1469 theids.append(id)
1457 1470
1458 1471 content = dict(msg_ids = theids)
1459 1472
1460 1473 self.session.send(self._query_socket, 'resubmit_request', content)
1461 1474
1462 1475 zmq.select([self._query_socket], [], [])
1463 1476 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1464 1477 if self.debug:
1465 1478 pprint(msg)
1466 1479 content = msg['content']
1467 1480 if content['status'] != 'ok':
1468 1481 raise self._unwrap_exception(content)
1469 1482 mapping = content['resubmitted']
1470 1483 new_ids = [ mapping[msg_id] for msg_id in theids ]
1471 1484
1472 1485 ar = AsyncHubResult(self, msg_ids=new_ids)
1473 1486
1474 1487 if block:
1475 1488 ar.wait()
1476 1489
1477 1490 return ar
1478 1491
1479 1492 @spin_first
1480 1493 def result_status(self, msg_ids, status_only=True):
1481 1494 """Check on the status of the result(s) of the apply request with `msg_ids`.
1482 1495
1483 1496 If status_only is False, then the actual results will be retrieved, else
1484 1497 only the status of the results will be checked.
1485 1498
1486 1499 Parameters
1487 1500 ----------
1488 1501
1489 1502 msg_ids : list of msg_ids
1490 1503 if int:
1491 1504 Passed as index to self.history for convenience.
1492 1505 status_only : bool (default: True)
1493 1506 if False:
1494 1507 Retrieve the actual results of completed tasks.
1495 1508
1496 1509 Returns
1497 1510 -------
1498 1511
1499 1512 results : dict
1500 1513 There will always be the keys 'pending' and 'completed', which will
1501 1514 be lists of msg_ids that are incomplete or complete. If `status_only`
1502 1515 is False, then completed results will be keyed by their `msg_id`.
        """
        if not isinstance(msg_ids, (list, tuple)):
            msg_ids = [msg_ids]

        theids = []
        for msg_id in msg_ids:
            if isinstance(msg_id, int):
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, basestring):
                raise TypeError("msg_ids must be str, not %r" % msg_id)
            theids.append(msg_id)

        completed = []
        local_results = {}

        # comment this block out to temporarily disable local shortcut:
        # (iterate over a copy, since we remove from theids as we go)
        for msg_id in list(theids):
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids: # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            self.session.send(self._query_socket, "result_request", content=content)
            zmq.select([self._query_socket], [], [])
            idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise self._unwrap_exception(content)
            buffers = msg['buffers']
        else:
            content = dict(completed=[], pending=[])

        content['completed'].extend(completed)

        if status_only:
            return content

        failures = []
        # load cached results into result:
        content.update(local_results)

        # update cache with results:
        for msg_id in sorted(theids):
            if msg_id in content['completed']:
                rec = content[msg_id]
                parent = rec['header']
                header = rec['result_header']
                rcontent = rec['result_content']
                iodict = rec['io']
                if isinstance(rcontent, str):
                    rcontent = self.session.unpack(rcontent)

                md = self.metadata[msg_id]
                md_msg = dict(
                    content=rcontent,
                    parent_header=parent,
                    header=header,
                    metadata=rec['result_metadata'],
                )
                md.update(self._extract_metadata(md_msg))
                if rec.get('received'):
                    md['received'] = rec['received']
                md.update(iodict)

                if rcontent['status'] == 'ok':
                    if header['msg_type'] == 'apply_reply':
                        res, buffers = serialize.unserialize_object(buffers)
                    elif header['msg_type'] == 'execute_reply':
                        res = ExecuteReply(msg_id, rcontent, md)
                    else:
                        raise KeyError("unhandled msg type: %r" % header['msg_type'])
                else:
                    res = self._unwrap_exception(rcontent)
                    failures.append(res)

                self.results[msg_id] = res
                content[msg_id] = res

        if len(theids) == 1 and failures:
            raise failures[0]

        error.collect_exceptions(failures, "result_status")
        return content

    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------

        targets : int/str/list of ints/strs
            the engines whose states are to be queried.
            default : all
        verbose : bool
            Whether to return queue lengths only (default), or lists of the
            msg_ids in each queue.
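
        Examples
        --------
        A quick sketch, assuming a connected Client ``rc``; the exact keys
        in the reply depend on the scheduler::

            rc.queue_status()  # e.g. {0: {'queue': 0, 'completed': 4, 'tasks': 0},
                               #       'unassigned': 0}
            rc.queue_status(0, verbose=True)  # msg_id lists for engine 0 only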
        """
        if targets == 'all':
            # allow 'all' to be evaluated on the engine
            engine_ids = None
        else:
            engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        content = rekey(content)
        if isinstance(targets, int):
            return content[targets]
        else:
            return content

    def _build_msgids_from_target(self, targets=None):
        """Build a list of msg_ids from the list of engine targets"""
        if not targets: # needed as _build_targets otherwise uses all engines
            return []
        target_ids = self._build_targets(targets)[0]
        return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)

    def _build_msgids_from_jobs(self, jobs=None):
        """Build a list of msg_ids from "jobs" """
        if not jobs:
            return []
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult" % bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        return msg_ids

    def purge_local_results(self, jobs=[], targets=[]):
        """Clears the client caches of results and frees such memory.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_local_results('all')` to scrub everything from the Client's db.

        The client must have no outstanding tasks before purging the caches.
        Raises `AssertionError` if there are still outstanding tasks.

        After this call all `AsyncResults` are invalid and should be discarded.

        If you must "reget" the results, you can still do so by using
        `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
        redownload the results from the hub if they are still available
        (i.e. `client.purge_hub_results(...)` has not been called).

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be purged.
        targets : int/str/list of ints/strs
            The targets, by int_id, whose entire results are to be purged.
            default : None
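
        Examples
        --------
        A minimal sketch, assuming a connected Client ``rc`` with no
        outstanding tasks::

            rc.purge_local_results(targets=[0, 1])  # drop cached results from engines 0-1
            rc.purge_local_results('all')           # or clear the whole local cache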
        """
        assert not self.outstanding, "Can't purge a client with outstanding tasks!"

        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")

        if jobs == 'all':
            self.results.clear()
            self.metadata.clear()
            return
        else:
            msg_ids = []
            msg_ids.extend(self._build_msgids_from_target(targets))
            msg_ids.extend(self._build_msgids_from_jobs(jobs))
            # explicit loop (not map) for the side effects; tolerate missing keys
            for msg_id in msg_ids:
                self.results.pop(msg_id, None)
                self.metadata.pop(msg_id, None)


    @spin_first
    def purge_hub_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_hub_results('all')` to scrub everything from the Hub's db.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.
            default : None
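
        Examples
        --------
        A brief sketch, assuming a connected Client ``rc``::

            ar = rc[:].apply_async(lambda: 1)
            ar.get()
            rc.purge_hub_results(ar)     # forget just this job on the Hub
            rc.purge_hub_results('all')  # or scrub the Hub's entire db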
        """
        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        if jobs == 'all':
            msg_ids = jobs
        else:
            msg_ids = self._build_msgids_from_jobs(jobs)

        content = dict(engine_ids=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

    def purge_results(self, jobs=[], targets=[]):
        """Clears the cached results from both the hub and the local client

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_results('all')` to scrub every cached result from both the Hub's and
        the Client's db.

        Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
        the same arguments.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.
            default : None
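
        Examples
        --------
        A short sketch, assuming a connected Client ``rc`` with no
        outstanding tasks::

            rc.purge_results('all')  # forget every result, locally and on the Hub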
        """
        self.purge_local_results(jobs=jobs, targets=targets)
        self.purge_hub_results(jobs=jobs, targets=targets)

    def purge_everything(self):
        """Clears all content from previous tasks, from both the hub and the local client

        In addition to calling `purge_results("all")`, it also deletes the history and
        other bookkeeping lists.
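
        Examples
        --------
        A one-line sketch, assuming a connected Client ``rc`` with no
        outstanding tasks::

            rc.purge_everything()  # results, metadata, history, digests all reset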
        """
        self.purge_results("all")
        self.history = []
        self.session.digest_history.clear()

    @spin_first
    def hub_history(self):
        """Get the Hub's history

        Just like the Client, the Hub has a history, which is a list of msg_ids.
        This will contain the history of all clients, and, depending on configuration,
        may contain history across multiple cluster sessions.

        Any msg_id returned here is a valid argument to `get_result`.

        Returns
        -------

        msg_ids : list of strs
            list of all msg_ids, ordered by task submission time.
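
        Examples
        --------
        A small sketch, assuming a connected Client ``rc``::

            hist = rc.hub_history()
            latest = hist[-1]           # most recently submitted task
            ar = rc.get_result(latest)  # any Hub msg_id works with get_result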
        """

        self.session.send(self._query_socket, "history_request", content={})
        idents, msg = self.session.recv(self._query_socket, 0)

        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        else:
            return content['history']

    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.
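
        Examples
        --------
        A hedged sketch of a MongoDB-style query, assuming a connected
        Client ``rc`` and a Hub db backend that supports querying::

            from datetime import datetime, timedelta

            # records completed in the last hour, fetching two keys only
            since = datetime.now() - timedelta(hours=1)
            recs = rc.db_query({'completed': {'$gt': since}},
                               keys=['msg_id', 'completed'])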
        """
        if isinstance(keys, basestring):
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i, rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen], buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]

        return records


__all__ = ['Client']