parallel.Client: ignore 'ssh' field of JSON connector if controller appears to be local...
MinRK -
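The user-visible effect of this change, sketched below (the profile path and hostname are illustrative, not taken from this commit): a client connecting to a controller on the same machine no longer picks up the `ssh` field from `ipcontroller-client.json`, while explicit tunneling still works.

```python
from IPython.parallel import Client

# controller on this machine: the JSON file's 'ssh' field is now ignored,
# because its 'location' is one of this host's IPs -- no accidental tunnel
rc = Client()  # reads the default profile's ipcontroller-client.json

# controller elsewhere, bound to localhost: ssh is still honored, whether
# read from the JSON file or passed explicitly
rc = Client('/path/to/ipcontroller-client.json',
            sshserver='user@gateway.example.com')
```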
@@ -1,1412 +1,1418 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import time
21 21 import warnings
22 22 from datetime import datetime
23 23 from getpass import getpass
24 24 from pprint import pprint
25 25
26 26 pjoin = os.path.join
27 27
28 28 import zmq
29 29 # from zmq.eventloop import ioloop, zmqstream
30 30
31 31 from IPython.config.configurable import MultipleInstanceError
32 32 from IPython.core.application import BaseIPythonApplication
33 33
34 34 from IPython.utils.jsonutil import rekey
35 35 from IPython.utils.localinterfaces import LOCAL_IPS
36 36 from IPython.utils.path import get_ipython_dir
37 37 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 38 Dict, List, Bool, Set)
39 39 from IPython.external.decorator import decorator
40 40 from IPython.external.ssh import tunnel
41 41
42 42 from IPython.parallel import error
43 43 from IPython.parallel import util
44 44
45 45 from IPython.zmq.session import Session, Message
46 46
47 47 from .asyncresult import AsyncResult, AsyncHubResult
48 48 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 49 from .view import DirectView, LoadBalancedView
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Decorators for Client methods
53 53 #--------------------------------------------------------------------------
54 54
55 55 @decorator
56 56 def spin_first(f, self, *args, **kwargs):
57 57 """Call spin() to sync state prior to calling the method."""
58 58 self.spin()
59 59 return f(self, *args, **kwargs)
60 60
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Classes
64 64 #--------------------------------------------------------------------------
65 65
66 66 class Metadata(dict):
67 67 """Subclass of dict for initializing metadata values.
68 68
69 69 Attribute access works on keys.
70 70
71 71 These objects have a strict set of keys - errors will raise if you try
72 72 to add new keys.
73 73 """
74 74 def __init__(self, *args, **kwargs):
75 75 dict.__init__(self)
76 76 md = {'msg_id' : None,
77 77 'submitted' : None,
78 78 'started' : None,
79 79 'completed' : None,
80 80 'received' : None,
81 81 'engine_uuid' : None,
82 82 'engine_id' : None,
83 83 'follow' : None,
84 84 'after' : None,
85 85 'status' : None,
86 86
87 87 'pyin' : None,
88 88 'pyout' : None,
89 89 'pyerr' : None,
90 90 'stdout' : '',
91 91 'stderr' : '',
92 92 }
93 93 self.update(md)
94 94 self.update(dict(*args, **kwargs))
95 95
96 96 def __getattr__(self, key):
97 97 """getattr aliased to getitem"""
98 98 if key in self.iterkeys():
99 99 return self[key]
100 100 else:
101 101 raise AttributeError(key)
102 102
103 103 def __setattr__(self, key, value):
104 104 """setattr aliased to setitem, with strict"""
105 105 if key in self.iterkeys():
106 106 self[key] = value
107 107 else:
108 108 raise AttributeError(key)
109 109
110 110 def __setitem__(self, key, value):
111 111 """strict static key enforcement"""
112 112 if key in self.iterkeys():
113 113 dict.__setitem__(self, key, value)
114 114 else:
115 115 raise KeyError(key)
116 116
117 117
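For reference, a small sketch of how this strict-key dict behaves (values invented):

```python
md = Metadata()
md['stdout'] += 'engine output\n'   # known keys: item access works
md.engine_id = 0                    # ...and attribute access is aliased to it
try:
    md['color'] = 'blue'            # unknown keys are rejected
except KeyError:
    print "no new keys allowed"
```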
118 118 class Client(HasTraits):
119 119 """A semi-synchronous client to the IPython ZMQ cluster
120 120
121 121 Parameters
122 122 ----------
123 123
124 124 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
125 125 Connection information for the Hub's registration. If a json connector
126 126 file is given, then likely no further configuration is necessary.
127 127 [Default: use profile]
128 128 profile : bytes
129 129 The name of the Cluster profile to be used to find connector information.
130 130 If run from an IPython application, the default profile will be the same
131 131 as the running application, otherwise it will be 'default'.
132 132 context : zmq.Context
133 133 Pass an existing zmq.Context instance, otherwise the client will create its own.
134 134 debug : bool
135 135 flag for lots of message printing for debug purposes
136 136 timeout : int/float
137 137 time (in seconds) to wait for connection replies from the Hub
138 138 [Default: 10]
139 139
140 140 #-------------- session related args ----------------
141 141
142 142 config : Config object
143 143 If specified, this will be relayed to the Session for configuration
144 144 username : str
145 145 set username for the session object
146 146 packer : str (import_string) or callable
147 147 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
148 148 function to serialize messages. Must support same input as
149 149 JSON, and output must be bytes.
150 150 You can pass a callable directly as `pack`
151 151 unpacker : str (import_string) or callable
152 152 The inverse of packer. Only necessary if packer is specified as *not* one
153 153 of 'json' or 'pickle'.
154 154
155 155 #-------------- ssh related args ----------------
156 156 # These are args for configuring the ssh tunnel to be used
157 157 # credentials are used to forward connections over ssh to the Controller
158 158 # Note that the ip given in `addr` needs to be relative to sshserver
159 159 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
160 160 # and set sshserver as the same machine the Controller is on. However,
161 161 # the only requirement is that sshserver is able to see the Controller
162 162 # (i.e. is within the same trusted network).
163 163
164 164 sshserver : str
165 165 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
166 166 If keyfile or password is specified, and this is not, it will default to
167 167 the ip given in addr.
168 168 sshkey : str; path to public ssh key file
169 169 This specifies a key to be used in ssh login, default None.
170 170 Regular default ssh keys will be used without specifying this argument.
171 171 password : str
172 172 Your ssh password to sshserver. Note that if this is left None,
173 173 you will be prompted for it if passwordless key based login is unavailable.
174 174 paramiko : bool
175 175 flag for whether to use paramiko instead of shell ssh for tunneling.
176 176 [default: True on win32, False otherwise]
177 177
178 178 #-------------- exec authentication args ----------------
179 179 If even localhost is untrusted, you can have some protection against
180 180 unauthorized execution by signing messages with HMAC digests.
181 181 Messages are still sent as cleartext, so if someone can snoop your
182 182 loopback traffic this will not protect your privacy, but will prevent
183 183 unauthorized execution.
184 184
185 185 exec_key : str
186 186 an authentication key or file containing a key
187 187 default: None
188 188
189 189
190 190 Attributes
191 191 ----------
192 192
193 193 ids : list of int engine IDs
194 194 requesting the ids attribute always synchronizes
195 195 the registration state. To request ids without synchronization,
196 196 use the semi-private _ids attribute.
197 197
198 198 history : list of msg_ids
199 199 a list of msg_ids, keeping track of all the execution
200 200 messages you have submitted in order.
201 201
202 202 outstanding : set of msg_ids
203 203 a set of msg_ids that have been submitted, but whose
204 204 results have not yet been received.
205 205
206 206 results : dict
207 207 a dict of all our results, keyed by msg_id
208 208
209 209 block : bool
210 210 determines default behavior when block not specified
211 211 in execution methods
212 212
213 213 Methods
214 214 -------
215 215
216 216 spin
217 217 flushes incoming results and registration state changes
218 218 control methods spin, and requesting `ids` also ensures the state is up to date
219 219
220 220 wait
221 221 wait on one or more msg_ids
222 222
223 223 execution methods
224 224 apply
225 225 legacy: execute, run
226 226
227 227 data movement
228 228 push, pull, scatter, gather
229 229
230 230 query methods
231 231 queue_status, get_result, purge, result_status
232 232
233 233 control methods
234 234 abort, shutdown
235 235
236 236 """
237 237
238 238
239 239 block = Bool(False)
240 240 outstanding = Set()
241 241 results = Instance('collections.defaultdict', (dict,))
242 242 metadata = Instance('collections.defaultdict', (Metadata,))
243 243 history = List()
244 244 debug = Bool(False)
245 245
246 246 profile=Unicode()
247 247 def _profile_default(self):
248 248 if BaseIPythonApplication.initialized():
249 249 # an IPython app *might* be running, try to get its profile
250 250 try:
251 251 return BaseIPythonApplication.instance().profile
252 252 except (AttributeError, MultipleInstanceError):
253 253 # could be a *different* subclass of config.Application,
254 254 # which would raise one of these two errors.
255 255 return u'default'
256 256 else:
257 257 return u'default'
258 258
259 259
260 260 _outstanding_dict = Instance('collections.defaultdict', (set,))
261 261 _ids = List()
262 262 _connected=Bool(False)
263 263 _ssh=Bool(False)
264 264 _context = Instance('zmq.Context')
265 265 _config = Dict()
266 266 _engines=Instance(util.ReverseDict, (), {})
267 267 # _hub_socket=Instance('zmq.Socket')
268 268 _query_socket=Instance('zmq.Socket')
269 269 _control_socket=Instance('zmq.Socket')
270 270 _iopub_socket=Instance('zmq.Socket')
271 271 _notification_socket=Instance('zmq.Socket')
272 272 _mux_socket=Instance('zmq.Socket')
273 273 _task_socket=Instance('zmq.Socket')
274 274 _task_scheme=Unicode()
275 275 _closed = False
276 276 _ignored_control_replies=Int(0)
277 277 _ignored_hub_replies=Int(0)
278 278
279 279 def __new__(self, *args, **kw):
280 280 # don't raise on positional args
281 281 return HasTraits.__new__(self, **kw)
282 282
283 283 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
284 284 context=None, debug=False, exec_key=None,
285 285 sshserver=None, sshkey=None, password=None, paramiko=None,
286 286 timeout=10, **extra_args
287 287 ):
288 288 if profile:
289 289 super(Client, self).__init__(debug=debug, profile=profile)
290 290 else:
291 291 super(Client, self).__init__(debug=debug)
292 292 if context is None:
293 293 context = zmq.Context.instance()
294 294 self._context = context
295 295
296 296 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
297 297 if self._cd is not None:
298 298 if url_or_file is None:
299 299 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
300 300 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
301 301 " Please specify at least one of url_or_file or profile."
302 302
303 303 try:
304 304 util.validate_url(url_or_file)
305 305 except AssertionError:
306 306 if not os.path.exists(url_or_file):
307 307 if self._cd:
308 308 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
309 309 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
310 310 with open(url_or_file) as f:
311 311 cfg = json.loads(f.read())
312 312 else:
313 313 cfg = {'url':url_or_file}
314 314
315 315 # sync defaults from args, json:
316 316 if sshserver:
317 317 cfg['ssh'] = sshserver
318 318 if exec_key:
319 319 cfg['exec_key'] = exec_key
320 320 exec_key = cfg['exec_key']
321 sshserver=cfg['ssh']
322 url = cfg['url']
323 321 location = cfg.setdefault('location', None)
324 322 cfg['url'] = util.disambiguate_url(cfg['url'], location)
325 323 url = cfg['url']
326 if location is not None:
327 proto,addr,port = util.split_url(url)
328 if addr == '127.0.0.1' and location not in LOCAL_IPS and not sshserver:
324 proto,addr,port = util.split_url(url)
325 if location is not None and addr == '127.0.0.1':
326 # location specified, and connection is expected to be local
327 if location not in LOCAL_IPS and not sshserver:
328 # load ssh from JSON *only* if the controller is not on
329 # this machine
330 sshserver=cfg['ssh']
331 if location not in LOCAL_IPS and not sshserver:
332 # warn if no ssh specified, but SSH is probably needed
333 # This is only a warning, because the most likely cause
334 # is a local Controller on a laptop whose IP is dynamic
329 335 warnings.warn("""
330 336 Controller appears to be listening on localhost, but not on this machine.
331 337 If this is true, you should specify Client(...,sshserver='you@%s')
332 338 or instruct your controller to listen on an external IP."""%location,
333 339 RuntimeWarning)
334 340
335 341 self._config = cfg
336 342
337 343 self._ssh = bool(sshserver or sshkey or password)
338 344 if self._ssh and sshserver is None:
339 345 # default to ssh via localhost
340 346 sshserver = url.split('://')[1].split(':')[0]
341 347 if self._ssh and password is None:
342 348 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
343 349 password=False
344 350 else:
345 351 password = getpass("SSH Password for %s: "%sshserver)
346 352 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
347 353
348 354 # configure and construct the session
349 355 if exec_key is not None:
350 356 if os.path.isfile(exec_key):
351 357 extra_args['keyfile'] = exec_key
352 358 else:
353 359 if isinstance(exec_key, unicode):
354 360 exec_key = exec_key.encode('ascii')
355 361 extra_args['key'] = exec_key
356 362 self.session = Session(**extra_args)
357 363
358 364 self._query_socket = self._context.socket(zmq.XREQ)
359 365 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
360 366 if self._ssh:
361 367 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
362 368 else:
363 369 self._query_socket.connect(url)
364 370
365 371 self.session.debug = self.debug
366 372
367 373 self._notification_handlers = {'registration_notification' : self._register_engine,
368 374 'unregistration_notification' : self._unregister_engine,
369 375 'shutdown_notification' : lambda msg: self.close(),
370 376 }
371 377 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
372 378 'apply_reply' : self._handle_apply_reply}
373 379 self._connect(sshserver, ssh_kwargs, timeout)
374 380
375 381 def __del__(self):
376 382 """cleanup sockets, but _not_ context."""
377 383 self.close()
378 384
379 385 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
380 386 if ipython_dir is None:
381 387 ipython_dir = get_ipython_dir()
382 388 if profile_dir is not None:
383 389 try:
384 390 self._cd = ProfileDir.find_profile_dir(profile_dir)
385 391 return
386 392 except ProfileDirError:
387 393 pass
388 394 elif profile is not None:
389 395 try:
390 396 self._cd = ProfileDir.find_profile_dir_by_name(
391 397 ipython_dir, profile)
392 398 return
393 399 except ProfileDirError:
394 400 pass
395 401 self._cd = None
396 402
397 403 def _update_engines(self, engines):
398 404 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
399 405 for k,v in engines.iteritems():
400 406 eid = int(k)
401 407 self._engines[eid] = bytes(v) # force not unicode
402 408 self._ids.append(eid)
403 409 self._ids = sorted(self._ids)
404 410 if sorted(self._engines.keys()) != range(len(self._engines)) and \
405 411 self._task_scheme == 'pure' and self._task_socket:
406 412 self._stop_scheduling_tasks()
407 413
408 414 def _stop_scheduling_tasks(self):
409 415 """Stop scheduling tasks because an engine has been unregistered
410 416 from a pure ZMQ scheduler.
411 417 """
412 418 self._task_socket.close()
413 419 self._task_socket = None
414 420 msg = "An engine has been unregistered, and we are using pure " +\
415 421 "ZMQ task scheduling. Task farming will be disabled."
416 422 if self.outstanding:
417 423 msg += " If you were running tasks when this happened, " +\
418 424 "some `outstanding` msg_ids may never resolve."
419 425 warnings.warn(msg, RuntimeWarning)
420 426
421 427 def _build_targets(self, targets):
422 428 """Turn valid target IDs or 'all' into two lists:
423 429 (int_ids, uuids).
424 430 """
425 431 if not self._ids:
426 432 # flush notification socket if no engines yet, just in case
427 433 if not self.ids:
428 434 raise error.NoEnginesRegistered("Can't build targets without any engines")
429 435
430 436 if targets is None:
431 437 targets = self._ids
432 438 elif isinstance(targets, str):
433 439 if targets.lower() == 'all':
434 440 targets = self._ids
435 441 else:
436 442 raise TypeError("%r not valid str target, must be 'all'"%(targets))
437 443 elif isinstance(targets, int):
438 444 if targets < 0:
439 445 targets = self.ids[targets]
440 446 if targets not in self._ids:
441 447 raise IndexError("No such engine: %i"%targets)
442 448 targets = [targets]
443 449
444 450 if isinstance(targets, slice):
445 451 indices = range(len(self._ids))[targets]
446 452 ids = self.ids
447 453 targets = [ ids[i] for i in indices ]
448 454
449 455 if not isinstance(targets, (tuple, list, xrange)):
450 456 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
451 457
452 458 return [self._engines[t] for t in targets], list(targets)
453 459
454 460 def _connect(self, sshserver, ssh_kwargs, timeout):
455 461 """setup all our socket connections to the cluster. This is called from
456 462 __init__."""
457 463
458 464 # Maybe allow reconnecting?
459 465 if self._connected:
460 466 return
461 467 self._connected=True
462 468
463 469 def connect_socket(s, url):
464 470 url = util.disambiguate_url(url, self._config['location'])
465 471 if self._ssh:
466 472 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
467 473 else:
468 474 return s.connect(url)
469 475
470 476 self.session.send(self._query_socket, 'connection_request')
471 477 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
472 478 poller = zmq.Poller()
473 479 poller.register(self._query_socket, zmq.POLLIN)
474 480 # poll expects milliseconds, timeout is seconds
475 481 evts = poller.poll(timeout*1000)
476 482 if not evts:
477 483 raise error.TimeoutError("Hub connection request timed out")
478 484 idents,msg = self.session.recv(self._query_socket,mode=0)
479 485 if self.debug:
480 486 pprint(msg)
481 487 msg = Message(msg)
482 488 content = msg.content
483 489 self._config['registration'] = dict(content)
484 490 if content.status == 'ok':
485 491 if content.mux:
486 492 self._mux_socket = self._context.socket(zmq.XREQ)
487 493 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
488 494 connect_socket(self._mux_socket, content.mux)
489 495 if content.task:
490 496 self._task_scheme, task_addr = content.task
491 497 self._task_socket = self._context.socket(zmq.XREQ)
492 498 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
493 499 connect_socket(self._task_socket, task_addr)
494 500 if content.notification:
495 501 self._notification_socket = self._context.socket(zmq.SUB)
496 502 connect_socket(self._notification_socket, content.notification)
497 503 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
498 504 # if content.query:
499 505 # self._query_socket = self._context.socket(zmq.XREQ)
500 506 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
501 507 # connect_socket(self._query_socket, content.query)
502 508 if content.control:
503 509 self._control_socket = self._context.socket(zmq.XREQ)
504 510 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
505 511 connect_socket(self._control_socket, content.control)
506 512 if content.iopub:
507 513 self._iopub_socket = self._context.socket(zmq.SUB)
508 514 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
509 515 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
510 516 connect_socket(self._iopub_socket, content.iopub)
511 517 self._update_engines(dict(content.engines))
512 518 else:
513 519 self._connected = False
514 520 raise Exception("Failed to connect!")
515 521
516 522 #--------------------------------------------------------------------------
517 523 # handlers and callbacks for incoming messages
518 524 #--------------------------------------------------------------------------
519 525
520 526 def _unwrap_exception(self, content):
521 527 """unwrap exception, and remap engine_id to int."""
522 528 e = error.unwrap_exception(content)
523 529 # print e.traceback
524 530 if e.engine_info:
525 531 e_uuid = e.engine_info['engine_uuid']
526 532 eid = self._engines[e_uuid]
527 533 e.engine_info['engine_id'] = eid
528 534 return e
529 535
530 536 def _extract_metadata(self, header, parent, content):
531 537 md = {'msg_id' : parent['msg_id'],
532 538 'received' : datetime.now(),
533 539 'engine_uuid' : header.get('engine', None),
534 540 'follow' : parent.get('follow', []),
535 541 'after' : parent.get('after', []),
536 542 'status' : content['status'],
537 543 }
538 544
539 545 if md['engine_uuid'] is not None:
540 546 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
541 547
542 548 if 'date' in parent:
543 549 md['submitted'] = parent['date']
544 550 if 'started' in header:
545 551 md['started'] = header['started']
546 552 if 'date' in header:
547 553 md['completed'] = header['date']
548 554 return md
549 555
550 556 def _register_engine(self, msg):
551 557 """Register a new engine, and update our connection info."""
552 558 content = msg['content']
553 559 eid = content['id']
554 560 d = {eid : content['queue']}
555 561 self._update_engines(d)
556 562
557 563 def _unregister_engine(self, msg):
558 564 """Unregister an engine that has died."""
559 565 content = msg['content']
560 566 eid = int(content['id'])
561 567 if eid in self._ids:
562 568 self._ids.remove(eid)
563 569 uuid = self._engines.pop(eid)
564 570
565 571 self._handle_stranded_msgs(eid, uuid)
566 572
567 573 if self._task_socket and self._task_scheme == 'pure':
568 574 self._stop_scheduling_tasks()
569 575
570 576 def _handle_stranded_msgs(self, eid, uuid):
571 577 """Handle messages known to be on an engine when the engine unregisters.
572 578
573 579 It is possible that this will fire prematurely - that is, an engine will
574 580 go down after completing a result, and the client will be notified
575 581 of the unregistration and later receive the successful result.
576 582 """
577 583
578 584 outstanding = self._outstanding_dict[uuid]
579 585
580 586 for msg_id in list(outstanding):
581 587 if msg_id in self.results:
582 588 # we already have this result
583 589 continue
584 590 try:
585 591 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
586 592 except:
587 593 content = error.wrap_exception()
588 594 # build a fake message:
589 595 parent = {}
590 596 header = {}
591 597 parent['msg_id'] = msg_id
592 598 header['engine'] = uuid
593 599 header['date'] = datetime.now()
594 600 msg = dict(parent_header=parent, header=header, content=content)
595 601 self._handle_apply_reply(msg)
596 602
597 603 def _handle_execute_reply(self, msg):
598 604 """Save the reply to an execute_request into our results.
599 605
600 606 execute messages are never actually used. apply is used instead.
601 607 """
602 608
603 609 parent = msg['parent_header']
604 610 msg_id = parent['msg_id']
605 611 if msg_id not in self.outstanding:
606 612 if msg_id in self.history:
607 613 print ("got stale result: %s"%msg_id)
608 614 else:
609 615 print ("got unknown result: %s"%msg_id)
610 616 else:
611 617 self.outstanding.remove(msg_id)
612 618 self.results[msg_id] = self._unwrap_exception(msg['content'])
613 619
614 620 def _handle_apply_reply(self, msg):
615 621 """Save the reply to an apply_request into our results."""
616 622 parent = msg['parent_header']
617 623 msg_id = parent['msg_id']
618 624 if msg_id not in self.outstanding:
619 625 if msg_id in self.history:
620 626 print ("got stale result: %s"%msg_id)
621 627 print self.results[msg_id]
622 628 print msg
623 629 else:
624 630 print ("got unknown result: %s"%msg_id)
625 631 else:
626 632 self.outstanding.remove(msg_id)
627 633 content = msg['content']
628 634 header = msg['header']
629 635
630 636 # construct metadata:
631 637 md = self.metadata[msg_id]
632 638 md.update(self._extract_metadata(header, parent, content))
633 639 # is this redundant?
634 640 self.metadata[msg_id] = md
635 641
636 642 e_outstanding = self._outstanding_dict[md['engine_uuid']]
637 643 if msg_id in e_outstanding:
638 644 e_outstanding.remove(msg_id)
639 645
640 646 # construct result:
641 647 if content['status'] == 'ok':
642 648 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
643 649 elif content['status'] == 'aborted':
644 650 self.results[msg_id] = error.TaskAborted(msg_id)
645 651 elif content['status'] == 'resubmitted':
646 652 # TODO: handle resubmission
647 653 pass
648 654 else:
649 655 self.results[msg_id] = self._unwrap_exception(content)
650 656
651 657 def _flush_notifications(self):
652 658 """Flush notifications of engine registrations waiting
653 659 in ZMQ queue."""
654 660 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
655 661 while msg is not None:
656 662 if self.debug:
657 663 pprint(msg)
658 664 msg_type = msg['msg_type']
659 665 handler = self._notification_handlers.get(msg_type, None)
660 666 if handler is None:
661 667 raise Exception("Unhandled message type: %s"%msg_type)
662 668 else:
663 669 handler(msg)
664 670 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
665 671
666 672 def _flush_results(self, sock):
667 673 """Flush task or queue results waiting in ZMQ queue."""
668 674 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
669 675 while msg is not None:
670 676 if self.debug:
671 677 pprint(msg)
672 678 msg_type = msg['msg_type']
673 679 handler = self._queue_handlers.get(msg_type, None)
674 680 if handler is None:
675 681 raise Exception("Unhandled message type: %s"%msg_type)
676 682 else:
677 683 handler(msg)
678 684 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
679 685
680 686 def _flush_control(self, sock):
681 687 """Flush replies from the control channel waiting
682 688 in the ZMQ queue.
683 689
684 690 Currently: ignore them."""
685 691 if self._ignored_control_replies <= 0:
686 692 return
687 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
688 694 while msg is not None:
689 695 self._ignored_control_replies -= 1
690 696 if self.debug:
691 697 pprint(msg)
692 698 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 699
694 700 def _flush_ignored_control(self):
695 701 """flush ignored control replies"""
696 702 while self._ignored_control_replies > 0:
697 703 self.session.recv(self._control_socket)
698 704 self._ignored_control_replies -= 1
699 705
700 706 def _flush_ignored_hub_replies(self):
701 707 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
702 708 while msg is not None:
703 709 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
704 710
705 711 def _flush_iopub(self, sock):
706 712 """Flush replies from the iopub channel waiting
707 713 in the ZMQ queue.
708 714 """
709 715 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
710 716 while msg is not None:
711 717 if self.debug:
712 718 pprint(msg)
713 719 parent = msg['parent_header']
714 720 msg_id = parent['msg_id']
715 721 content = msg['content']
716 722 header = msg['header']
717 723 msg_type = msg['msg_type']
718 724
719 725 # init metadata:
720 726 md = self.metadata[msg_id]
721 727
722 728 if msg_type == 'stream':
723 729 name = content['name']
724 730 s = md[name] or ''
725 731 md[name] = s + content['data']
726 732 elif msg_type == 'pyerr':
727 733 md.update({'pyerr' : self._unwrap_exception(content)})
728 734 elif msg_type == 'pyin':
729 735 md.update({'pyin' : content['code']})
730 736 else:
731 737 md.update({msg_type : content.get('data', '')})
732 738
733 739 # redundant?
734 740 self.metadata[msg_id] = md
735 741
736 742 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
737 743
738 744 #--------------------------------------------------------------------------
739 745 # len, getitem
740 746 #--------------------------------------------------------------------------
741 747
742 748 def __len__(self):
743 749 """len(client) returns # of engines."""
744 750 return len(self.ids)
745 751
746 752 def __getitem__(self, key):
747 753 """index access returns DirectView multiplexer objects
748 754
749 755 Must be int, slice, or list/tuple/xrange of ints"""
750 756 if not isinstance(key, (int, slice, tuple, list, xrange)):
751 757 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
752 758 else:
753 759 return self.direct_view(key)
754 760
755 761 #--------------------------------------------------------------------------
756 762 # Begin public methods
757 763 #--------------------------------------------------------------------------
758 764
759 765 @property
760 766 def ids(self):
761 767 """Always up-to-date ids property."""
762 768 self._flush_notifications()
763 769 # always copy:
764 770 return list(self._ids)
765 771
766 772 def close(self):
767 773 if self._closed:
768 774 return
769 775 snames = filter(lambda n: n.endswith('socket'), dir(self))
770 776 for socket in map(lambda name: getattr(self, name), snames):
771 777 if isinstance(socket, zmq.Socket) and not socket.closed:
772 778 socket.close()
773 779 self._closed = True
774 780
775 781 def spin(self):
776 782 """Flush any registration notifications and execution results
777 783 waiting in the ZMQ queue.
778 784 """
779 785 if self._notification_socket:
780 786 self._flush_notifications()
781 787 if self._mux_socket:
782 788 self._flush_results(self._mux_socket)
783 789 if self._task_socket:
784 790 self._flush_results(self._task_socket)
785 791 if self._control_socket:
786 792 self._flush_control(self._control_socket)
787 793 if self._iopub_socket:
788 794 self._flush_iopub(self._iopub_socket)
789 795 if self._query_socket:
790 796 self._flush_ignored_hub_replies()
791 797
792 798 def wait(self, jobs=None, timeout=-1):
793 799 """waits on one or more `jobs`, for up to `timeout` seconds.
794 800
795 801 Parameters
796 802 ----------
797 803
798 804 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
799 805 ints are indices to self.history
800 806 strs are msg_ids
801 807 default: wait on all outstanding messages
802 808 timeout : float
803 809 a time in seconds, after which to give up.
804 810 default is -1, which means no timeout
805 811
806 812 Returns
807 813 -------
808 814
809 815 True : when all msg_ids are done
810 816 False : timeout reached, some msg_ids still outstanding
811 817 """
812 818 tic = time.time()
813 819 if jobs is None:
814 820 theids = self.outstanding
815 821 else:
816 822 if isinstance(jobs, (int, str, AsyncResult)):
817 823 jobs = [jobs]
818 824 theids = set()
819 825 for job in jobs:
820 826 if isinstance(job, int):
821 827 # index access
822 828 job = self.history[job]
823 829 elif isinstance(job, AsyncResult):
824 830 map(theids.add, job.msg_ids)
825 831 continue
826 832 theids.add(job)
827 833 if not theids.intersection(self.outstanding):
828 834 return True
829 835 self.spin()
830 836 while theids.intersection(self.outstanding):
831 837 if timeout >= 0 and ( time.time()-tic ) > timeout:
832 838 break
833 839 time.sleep(1e-3)
834 840 self.spin()
835 841 return len(theids.intersection(self.outstanding)) == 0
836 842
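A usage sketch for `wait` (assumes a running cluster; `slow_task` is a stand-in function):

```python
from IPython.parallel import Client

def slow_task():
    import time
    time.sleep(10)
    return 'done'

rc = Client()
ar = rc[:].apply_async(slow_task)   # AsyncResult for all engines
if not rc.wait([ar], timeout=5):    # give up after 5 seconds
    print "still outstanding: %i messages" % len(rc.outstanding)
```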
837 843 #--------------------------------------------------------------------------
838 844 # Control methods
839 845 #--------------------------------------------------------------------------
840 846
841 847 @spin_first
842 848 def clear(self, targets=None, block=None):
843 849 """Clear the namespace in target(s)."""
844 850 block = self.block if block is None else block
845 851 targets = self._build_targets(targets)[0]
846 852 for t in targets:
847 853 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
848 854 error = False
849 855 if block:
850 856 self._flush_ignored_control()
851 857 for i in range(len(targets)):
852 858 idents,msg = self.session.recv(self._control_socket,0)
853 859 if self.debug:
854 860 pprint(msg)
855 861 if msg['content']['status'] != 'ok':
856 862 error = self._unwrap_exception(msg['content'])
857 863 else:
858 864 self._ignored_control_replies += len(targets)
859 865 if error:
860 866 raise error
861 867
862 868
863 869 @spin_first
864 870 def abort(self, jobs=None, targets=None, block=None):
865 871 """Abort specific jobs from the execution queues of target(s).
866 872
867 873 This is a mechanism to prevent jobs that have already been submitted
868 874 from executing.
869 875
870 876 Parameters
871 877 ----------
872 878
873 879 jobs : msg_id, list of msg_ids, or AsyncResult
874 880 The jobs to be aborted
875 881
876 882
877 883 """
878 884 block = self.block if block is None else block
879 885 targets = self._build_targets(targets)[0]
880 886 msg_ids = []
881 887 if isinstance(jobs, (basestring,AsyncResult)):
882 888 jobs = [jobs]
883 889 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
884 890 if bad_ids:
885 891 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
886 892 for j in jobs:
887 893 if isinstance(j, AsyncResult):
888 894 msg_ids.extend(j.msg_ids)
889 895 else:
890 896 msg_ids.append(j)
891 897 content = dict(msg_ids=msg_ids)
892 898 for t in targets:
893 899 self.session.send(self._control_socket, 'abort_request',
894 900 content=content, ident=t)
895 901 error = False
896 902 if block:
897 903 self._flush_ignored_control()
898 904 for i in range(len(targets)):
899 905 idents,msg = self.session.recv(self._control_socket,0)
900 906 if self.debug:
901 907 pprint(msg)
902 908 if msg['content']['status'] != 'ok':
903 909 error = self._unwrap_exception(msg['content'])
904 910 else:
905 911 self._ignored_control_replies += len(targets)
906 912 if error:
907 913 raise error
908 914
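An abort sketch (assumes `ar` is an AsyncResult from an earlier submission on this client):

```python
# stop queued-but-unstarted jobs; in this version `jobs` must be given,
# as a msg_id, a list of msg_ids, or an AsyncResult
rc.abort(jobs=ar.msg_ids, targets='all', block=True)
```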
909 915 @spin_first
910 916 def shutdown(self, targets=None, restart=False, hub=False, block=None):
911 917 """Terminates one or more engine processes, optionally including the hub."""
912 918 block = self.block if block is None else block
913 919 if hub:
914 920 targets = 'all'
915 921 targets = self._build_targets(targets)[0]
916 922 for t in targets:
917 923 self.session.send(self._control_socket, 'shutdown_request',
918 924 content={'restart':restart},ident=t)
919 925 error = False
920 926 if block or hub:
921 927 self._flush_ignored_control()
922 928 for i in range(len(targets)):
923 929 idents,msg = self.session.recv(self._control_socket, 0)
924 930 if self.debug:
925 931 pprint(msg)
926 932 if msg['content']['status'] != 'ok':
927 933 error = self._unwrap_exception(msg['content'])
928 934 else:
929 935 self._ignored_control_replies += len(targets)
930 936
931 937 if hub:
932 938 time.sleep(0.25)
933 939 self.session.send(self._query_socket, 'shutdown_request')
934 940 idents,msg = self.session.recv(self._query_socket, 0)
935 941 if self.debug:
936 942 pprint(msg)
937 943 if msg['content']['status'] != 'ok':
938 944 error = self._unwrap_exception(msg['content'])
939 945
940 946 if error:
941 947 raise error
942 948
943 949 #--------------------------------------------------------------------------
944 950 # Execution related methods
945 951 #--------------------------------------------------------------------------
946 952
947 953 def _maybe_raise(self, result):
948 954 """wrapper for maybe raising an exception if apply failed."""
949 955 if isinstance(result, error.RemoteError):
950 956 raise result
951 957
952 958 return result
953 959
954 960 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
955 961 ident=None):
956 962 """construct and send an apply message via a socket.
957 963
958 964 This is the principal method with which all engine execution is performed by views.
959 965 """
960 966
961 967 assert not self._closed, "cannot use me anymore, I'm closed!"
962 968 # defaults:
963 969 args = args if args is not None else []
964 970 kwargs = kwargs if kwargs is not None else {}
965 971 subheader = subheader if subheader is not None else {}
966 972
967 973 # validate arguments
968 974 if not callable(f):
969 975 raise TypeError("f must be callable, not %s"%type(f))
970 976 if not isinstance(args, (tuple, list)):
971 977 raise TypeError("args must be tuple or list, not %s"%type(args))
972 978 if not isinstance(kwargs, dict):
973 979 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
974 980 if not isinstance(subheader, dict):
975 981 raise TypeError("subheader must be dict, not %s"%type(subheader))
976 982
977 983 bufs = util.pack_apply_message(f,args,kwargs)
978 984
979 985 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
980 986 subheader=subheader, track=track)
981 987
982 988 msg_id = msg['msg_id']
983 989 self.outstanding.add(msg_id)
984 990 if ident:
985 991 # possibly routed to a specific engine
986 992 if isinstance(ident, list):
987 993 ident = ident[-1]
988 994 if ident in self._engines.values():
989 995 # save for later, in case of engine death
990 996 self._outstanding_dict[ident].add(msg_id)
991 997 self.history.append(msg_id)
992 998 self.metadata[msg_id]['submitted'] = datetime.now()
993 999
994 1000 return msg
995 1001
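Views normally drive this method; a direct, low-level sketch for reference (the private `_task_socket` is reached into only for illustration):

```python
import operator

# equivalent plumbing to what a LoadBalancedView does on apply()
msg = rc.send_apply_message(rc._task_socket, operator.add, args=(1, 2))
ar = rc.get_result(msg['msg_id'])   # wrap the msg_id in an AsyncResult
print ar.get()                      # -> 3, once an engine has run it
```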
996 1002 #--------------------------------------------------------------------------
997 1003 # construct a View object
998 1004 #--------------------------------------------------------------------------
999 1005
1000 1006 def load_balanced_view(self, targets=None):
1001 1007 """construct a DirectView object.
1002 1008
1003 1009 If no arguments are specified, create a LoadBalancedView
1004 1010 using all engines.
1005 1011
1006 1012 Parameters
1007 1013 ----------
1008 1014
1009 1015 targets: list,slice,int,etc. [default: use all engines]
1010 1016 The subset of engines across which to load-balance
1011 1017 """
1012 1018 if targets is not None:
1013 1019 targets = self._build_targets(targets)[1]
1014 1020 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1015 1021
1016 1022 def direct_view(self, targets='all'):
1017 1023 """construct a DirectView object.
1018 1024
1019 1025 If no targets are specified, create a DirectView
1020 1026 using all engines.
1021 1027
1022 1028 Parameters
1023 1029 ----------
1024 1030
1025 1031 targets: list,slice,int,etc. [default: use all engines]
1026 1032 The engines to use for the View
1027 1033 """
1028 1034 single = isinstance(targets, int)
1029 1035 targets = self._build_targets(targets)[1]
1030 1036 if single:
1031 1037 targets = targets[0]
1032 1038 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1033 1039
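How these constructors are usually reached (index access on the client is sugar for `direct_view`):

```python
dv = rc[:]                           # DirectView on all engines
e0 = rc[0]                           # DirectView on a single engine
lv = rc.load_balanced_view()         # LoadBalancedView over all engines
lv2 = rc.load_balanced_view([0, 2])  # balance over a subset only
```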
1034 1040 #--------------------------------------------------------------------------
1035 1041 # Query methods
1036 1042 #--------------------------------------------------------------------------
1037 1043
1038 1044 @spin_first
1039 1045 def get_result(self, indices_or_msg_ids=None, block=None):
1040 1046 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1041 1047
1042 1048 If the client already has the results, no request to the Hub will be made.
1043 1049
1044 1050 This is a convenient way to construct AsyncResult objects, which are wrappers
1045 1051 that include metadata about execution, and allow for awaiting results that
1046 1052 were not submitted by this Client.
1047 1053
1048 1054 It can also be a convenient way to retrieve the metadata associated with
1049 1055 blocking execution, since it always retrieves the full record, not just the return value.
1050 1056
1051 1057 Examples
1052 1058 --------
1053 1059 ::
1054 1060
1055 1061 In [10]: r = client.apply()
1056 1062
1057 1063 Parameters
1058 1064 ----------
1059 1065
1060 1066 indices_or_msg_ids : integer history index, str msg_id, or list of either
1061 1067 The indices or msg_ids of indices to be retrieved
1062 1068
1063 1069 block : bool
1064 1070 Whether to wait for the result to be done
1065 1071
1066 1072 Returns
1067 1073 -------
1068 1074
1069 1075 AsyncResult
1070 1076 A single AsyncResult object will always be returned.
1071 1077
1072 1078 AsyncHubResult
1073 1079 A subclass of AsyncResult that retrieves results from the Hub
1074 1080
1075 1081 """
1076 1082 block = self.block if block is None else block
1077 1083 if indices_or_msg_ids is None:
1078 1084 indices_or_msg_ids = -1
1079 1085
1080 1086 if not isinstance(indices_or_msg_ids, (list,tuple)):
1081 1087 indices_or_msg_ids = [indices_or_msg_ids]
1082 1088
1083 1089 theids = []
1084 1090 for id in indices_or_msg_ids:
1085 1091 if isinstance(id, int):
1086 1092 id = self.history[id]
1087 1093 if not isinstance(id, str):
1088 1094 raise TypeError("indices must be str or int, not %r"%id)
1089 1095 theids.append(id)
1090 1096
1091 1097 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1092 1098 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1093 1099
1094 1100 if remote_ids:
1095 1101 ar = AsyncHubResult(self, msg_ids=theids)
1096 1102 else:
1097 1103 ar = AsyncResult(self, msg_ids=theids)
1098 1104
1099 1105 if block:
1100 1106 ar.wait()
1101 1107
1102 1108 return ar
1103 1109
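A retrieval sketch (history indices and blocking behavior as documented above):

```python
ar = rc.get_result(-1)     # most recent submission, by history index
print ar.get()

# several at once, without waiting for completion
ar = rc.get_result(rc.history[-3:], block=False)
print ar.ready()
```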
1104 1110 @spin_first
1105 1111 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1106 1112 """Resubmit one or more tasks.
1107 1113
1108 1114 in-flight tasks may not be resubmitted.
1109 1115
1110 1116 Parameters
1111 1117 ----------
1112 1118
1113 1119 indices_or_msg_ids : integer history index, str msg_id, or list of either
1114 1120 The indices or msg_ids of indices to be retrieved
1115 1121
1116 1122 block : bool
1117 1123 Whether to wait for the result to be done
1118 1124
1119 1125 Returns
1120 1126 -------
1121 1127
1122 1128 AsyncHubResult
1123 1129 A subclass of AsyncResult that retrieves results from the Hub
1124 1130
1125 1131 """
1126 1132 block = self.block if block is None else block
1127 1133 if indices_or_msg_ids is None:
1128 1134 indices_or_msg_ids = -1
1129 1135
1130 1136 if not isinstance(indices_or_msg_ids, (list,tuple)):
1131 1137 indices_or_msg_ids = [indices_or_msg_ids]
1132 1138
1133 1139 theids = []
1134 1140 for id in indices_or_msg_ids:
1135 1141 if isinstance(id, int):
1136 1142 id = self.history[id]
1137 1143 if not isinstance(id, str):
1138 1144 raise TypeError("indices must be str or int, not %r"%id)
1139 1145 theids.append(id)
1140 1146
1141 1147 for msg_id in theids:
1142 1148 self.outstanding.discard(msg_id)
1143 1149 if msg_id in self.history:
1144 1150 self.history.remove(msg_id)
1145 1151 self.results.pop(msg_id, None)
1146 1152 self.metadata.pop(msg_id, None)
1147 1153 content = dict(msg_ids = theids)
1148 1154
1149 1155 self.session.send(self._query_socket, 'resubmit_request', content)
1150 1156
1151 1157 zmq.select([self._query_socket], [], [])
1152 1158 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1153 1159 if self.debug:
1154 1160 pprint(msg)
1155 1161 content = msg['content']
1156 1162 if content['status'] != 'ok':
1157 1163 raise self._unwrap_exception(content)
1158 1164
1159 1165 ar = AsyncHubResult(self, msg_ids=theids)
1160 1166
1161 1167 if block:
1162 1168 ar.wait()
1163 1169
1164 1170 return ar
1165 1171
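A resubmission sketch (the tasks must be finished, per the docstring):

```python
ar = rc.get_result(-1)
ar.wait()                        # in-flight tasks may not be resubmitted
ar2 = rc.resubmit(ar.msg_ids)    # AsyncHubResult for the re-run tasks
print ar2.get()
```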
1166 1172 @spin_first
1167 1173 def result_status(self, msg_ids, status_only=True):
1168 1174 """Check on the status of the result(s) of the apply request with `msg_ids`.
1169 1175
1170 1176 If status_only is False, then the actual results will be retrieved, else
1171 1177 only the status of the results will be checked.
1172 1178
1173 1179 Parameters
1174 1180 ----------
1175 1181
1176 1182 msg_ids : list of msg_ids
1177 1183 if int:
1178 1184 Passed as index to self.history for convenience.
1179 1185 status_only : bool (default: True)
1180 1186 if False:
1181 1187 Retrieve the actual results of completed tasks.
1182 1188
1183 1189 Returns
1184 1190 -------
1185 1191
1186 1192 results : dict
1187 1193 There will always be the keys 'pending' and 'completed', which will
1188 1194 be lists of msg_ids that are incomplete or complete. If `status_only`
1189 1195 is False, then completed results will be keyed by their `msg_id`.
1190 1196 """
1191 1197 if not isinstance(msg_ids, (list,tuple)):
1192 1198 msg_ids = [msg_ids]
1193 1199
1194 1200 theids = []
1195 1201 for msg_id in msg_ids:
1196 1202 if isinstance(msg_id, int):
1197 1203 msg_id = self.history[msg_id]
1198 1204 if not isinstance(msg_id, basestring):
1199 1205 raise TypeError("msg_ids must be str, not %r"%msg_id)
1200 1206 theids.append(msg_id)
1201 1207
1202 1208 completed = []
1203 1209 local_results = {}
1204 1210
1205 1211 # comment this block out to temporarily disable local shortcut:
1206 1212 for msg_id in theids:
1207 1213 if msg_id in self.results:
1208 1214 completed.append(msg_id)
1209 1215 local_results[msg_id] = self.results[msg_id]
1210 1216 theids.remove(msg_id)
1211 1217
1212 1218 if theids: # some not locally cached
1213 1219 content = dict(msg_ids=theids, status_only=status_only)
1214 1220 msg = self.session.send(self._query_socket, "result_request", content=content)
1215 1221 zmq.select([self._query_socket], [], [])
1216 1222 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1217 1223 if self.debug:
1218 1224 pprint(msg)
1219 1225 content = msg['content']
1220 1226 if content['status'] != 'ok':
1221 1227 raise self._unwrap_exception(content)
1222 1228 buffers = msg['buffers']
1223 1229 else:
1224 1230 content = dict(completed=[],pending=[])
1225 1231
1226 1232 content['completed'].extend(completed)
1227 1233
1228 1234 if status_only:
1229 1235 return content
1230 1236
1231 1237 failures = []
1232 1238 # load cached results into result:
1233 1239 content.update(local_results)
1234 1240
1235 1241 # update cache with results:
1236 1242 for msg_id in sorted(theids):
1237 1243 if msg_id in content['completed']:
1238 1244 rec = content[msg_id]
1239 1245 parent = rec['header']
1240 1246 header = rec['result_header']
1241 1247 rcontent = rec['result_content']
1242 1248 iodict = rec['io']
1243 1249 if isinstance(rcontent, str):
1244 1250 rcontent = self.session.unpack(rcontent)
1245 1251
1246 1252 md = self.metadata[msg_id]
1247 1253 md.update(self._extract_metadata(header, parent, rcontent))
1248 1254 md.update(iodict)
1249 1255
1250 1256 if rcontent['status'] == 'ok':
1251 1257 res,buffers = util.unserialize_object(buffers)
1252 1258 else:
1253 1259 print rcontent
1254 1260 res = self._unwrap_exception(rcontent)
1255 1261 failures.append(res)
1256 1262
1257 1263 self.results[msg_id] = res
1258 1264 content[msg_id] = res
1259 1265
1260 1266 if len(theids) == 1 and failures:
1261 1267 raise failures[0]
1262 1268
1263 1269 error.collect_exceptions(failures, "result_status")
1264 1270 return content
1265 1271
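A status-polling sketch; the reply always carries 'pending' and 'completed' lists:

```python
status = rc.result_status(rc.history[-5:])   # last five submissions
print "%i pending, %i completed" % (len(status['pending']),
                                    len(status['completed']))
```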
1266 1272 @spin_first
1267 1273 def queue_status(self, targets='all', verbose=False):
1268 1274 """Fetch the status of engine queues.
1269 1275
1270 1276 Parameters
1271 1277 ----------
1272 1278
1273 1279 targets : int/str/list of ints/strs
1274 1280 the engines whose states are to be queried.
1275 1281 default : all
1276 1282 verbose : bool
1277 1283 Whether to return lengths only, or lists of ids for each element
1278 1284 """
1279 1285 engine_ids = self._build_targets(targets)[1]
1280 1286 content = dict(targets=engine_ids, verbose=verbose)
1281 1287 self.session.send(self._query_socket, "queue_request", content=content)
1282 1288 idents,msg = self.session.recv(self._query_socket, 0)
1283 1289 if self.debug:
1284 1290 pprint(msg)
1285 1291 content = msg['content']
1286 1292 status = content.pop('status')
1287 1293 if status != 'ok':
1288 1294 raise self._unwrap_exception(content)
1289 1295 content = rekey(content)
1290 1296 if isinstance(targets, int):
1291 1297 return content[targets]
1292 1298 else:
1293 1299 return content
1294 1300
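A queue-inspection sketch (the shape of each engine's entry is an assumption about the Hub's reply; the int rekeying is from the code above):

```python
qs = rc.queue_status()    # dict keyed by int engine id after rekey()
for eid in sorted(qs):
    print eid, qs[eid]
print rc.queue_status(0)  # int target: only that engine's entry
```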
1295 1301 @spin_first
1296 1302 def purge_results(self, jobs=[], targets=[]):
1297 1303 """Tell the Hub to forget results.
1298 1304
1299 1305 Individual results can be purged by msg_id, or the entire
1300 1306 history of specific targets can be purged.
1301 1307
1302 1308 Parameters
1303 1309 ----------
1304 1310
1305 1311 jobs : str or list of str or AsyncResult objects
1306 1312 the msg_ids whose results should be forgotten.
1307 1313 targets : int/str/list of ints/strs
1308 1314 The targets, by uuid or int_id, whose entire history is to be purged.
1309 1315 Use `targets='all'` to scrub everything from the Hub's memory.
1310 1316
1311 1317 default : None
1312 1318 """
1313 1319 if not targets and not jobs:
1314 1320 raise ValueError("Must specify at least one of `targets` and `jobs`")
1315 1321 if targets:
1316 1322 targets = self._build_targets(targets)[1]
1317 1323
1318 1324 # construct msg_ids from jobs
1319 1325 msg_ids = []
1320 1326 if isinstance(jobs, (basestring,AsyncResult)):
1321 1327 jobs = [jobs]
1322 1328 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1323 1329 if bad_ids:
1324 1330 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1325 1331 for j in jobs:
1326 1332 if isinstance(j, AsyncResult):
1327 1333 msg_ids.extend(j.msg_ids)
1328 1334 else:
1329 1335 msg_ids.append(j)
1330 1336
1331 1337 content = dict(targets=targets, msg_ids=msg_ids)
1332 1338 self.session.send(self._query_socket, "purge_request", content=content)
1333 1339 idents, msg = self.session.recv(self._query_socket, 0)
1334 1340 if self.debug:
1335 1341 pprint(msg)
1336 1342 content = msg['content']
1337 1343 if content['status'] != 'ok':
1338 1344 raise self._unwrap_exception(content)
1339 1345
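A purge sketch, per the docstring's options:

```python
rc.purge_results(jobs=ar)          # forget specific results (AsyncResult ok)
rc.purge_results(targets=[0, 1])   # or whole engine histories
rc.purge_results(targets='all')    # scrub everything the Hub remembers
```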
1340 1346 @spin_first
1341 1347 def hub_history(self):
1342 1348 """Get the Hub's history
1343 1349
1344 1350 Just like the Client, the Hub has a history, which is a list of msg_ids.
1345 1351 This will contain the history of all clients, and, depending on configuration,
1346 1352 may contain history across multiple cluster sessions.
1347 1353
1348 1354 Any msg_id returned here is a valid argument to `get_result`.
1349 1355
1350 1356 Returns
1351 1357 -------
1352 1358
1353 1359 msg_ids : list of strs
1354 1360 list of all msg_ids, ordered by task submission time.
1355 1361 """
1356 1362
1357 1363 self.session.send(self._query_socket, "history_request", content={})
1358 1364 idents, msg = self.session.recv(self._query_socket, 0)
1359 1365
1360 1366 if self.debug:
1361 1367 pprint(msg)
1362 1368 content = msg['content']
1363 1369 if content['status'] != 'ok':
1364 1370 raise self._unwrap_exception(content)
1365 1371 else:
1366 1372 return content['history']
1367 1373
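Combining the Hub's history with `get_result`, as the docstring suggests:

```python
msg_ids = rc.hub_history()          # every msg_id the Hub has seen, in order
ar = rc.get_result(msg_ids[-4:])    # fetch the four most recent, even if
print ar.get()                      # they came from a different client
```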
1368 1374 @spin_first
1369 1375 def db_query(self, query, keys=None):
1370 1376 """Query the Hub's TaskRecord database
1371 1377
1372 1378 This will return a list of task record dicts that match `query`
1373 1379
1374 1380 Parameters
1375 1381 ----------
1376 1382
1377 1383 query : mongodb query dict
1378 1384 The search dict. See mongodb query docs for details.
1379 1385 keys : list of strs [optional]
1380 1386 The subset of keys to be returned. The default is to fetch everything but buffers.
1381 1387 'msg_id' will *always* be included.
1382 1388 """
1383 1389 if isinstance(keys, basestring):
1384 1390 keys = [keys]
1385 1391 content = dict(query=query, keys=keys)
1386 1392 self.session.send(self._query_socket, "db_request", content=content)
1387 1393 idents, msg = self.session.recv(self._query_socket, 0)
1388 1394 if self.debug:
1389 1395 pprint(msg)
1390 1396 content = msg['content']
1391 1397 if content['status'] != 'ok':
1392 1398 raise self._unwrap_exception(content)
1393 1399
1394 1400 records = content['records']
1395 1401
1396 1402 buffer_lens = content['buffer_lens']
1397 1403 result_buffer_lens = content['result_buffer_lens']
1398 1404 buffers = msg['buffers']
1399 1405 has_bufs = buffer_lens is not None
1400 1406 has_rbufs = result_buffer_lens is not None
1401 1407 for i,rec in enumerate(records):
1402 1408 # relink buffers
1403 1409 if has_bufs:
1404 1410 blen = buffer_lens[i]
1405 1411 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1406 1412 if has_rbufs:
1407 1413 blen = result_buffer_lens[i]
1408 1414 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1409 1415
1410 1416 return records
1411 1417
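A `db_query` sketch using mongodb-style operators; the TaskRecord field names here ('completed', 'msg_id') are assumptions about the Hub's schema:

```python
from datetime import datetime, timedelta

# records that finished in the last hour, fetching only two fields
cutoff = datetime.now() - timedelta(hours=1)
recs = rc.db_query({'completed': {'$gte': cutoff}},
                   keys=['msg_id', 'completed'])
print "%i tasks finished in the last hour" % len(recs)
```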
1412 1418 __all__ = [ 'Client' ]