better warning on non-local controller without ssh...
MinRK
@@ -1,1412 +1,1412 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import time
21 21 import warnings
22 22 from datetime import datetime
23 23 from getpass import getpass
24 24 from pprint import pprint
25 25
26 26 pjoin = os.path.join
27 27
28 28 import zmq
29 29 # from zmq.eventloop import ioloop, zmqstream
30 30
31 31 from IPython.config.configurable import MultipleInstanceError
32 32 from IPython.core.application import BaseIPythonApplication
33 33
34 34 from IPython.utils.jsonutil import rekey
35 35 from IPython.utils.localinterfaces import LOCAL_IPS
36 36 from IPython.utils.path import get_ipython_dir
37 37 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 38 Dict, List, Bool, Set)
39 39 from IPython.external.decorator import decorator
40 40 from IPython.external.ssh import tunnel
41 41
42 42 from IPython.parallel import error
43 43 from IPython.parallel import util
44 44
45 45 from IPython.zmq.session import Session, Message
46 46
47 47 from .asyncresult import AsyncResult, AsyncHubResult
48 48 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 49 from .view import DirectView, LoadBalancedView
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Decorators for Client methods
53 53 #--------------------------------------------------------------------------
54 54
55 55 @decorator
56 56 def spin_first(f, self, *args, **kwargs):
57 57 """Call spin() to sync state prior to calling the method."""
58 58 self.spin()
59 59 return f(self, *args, **kwargs)
60 60
61 61
#--------------------------------------------------------------------------
# Classes
#--------------------------------------------------------------------------

class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        if key in self.iterkeys():
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict key enforcement"""
        if key in self.iterkeys():
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self.iterkeys():
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)


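# A quick sketch of how the strict Metadata dict behaves (illustrative only,
# not part of this module):
#
#     md = Metadata()
#     md.stdout += 'hello\n'   # attribute access works on the fixed keys
#     md['engine_id'] = 0
#     md.engine_id             # -> 0
#     md['bogus'] = 1          # raises KeyError: the key set is fixed
#     md.bogus = 1             # raises AttributeError for the same reason
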
class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster

    Parameters
    ----------

    url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration. If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        If run from an IPython application, the default profile will be the same
        as the running application, otherwise it will be 'default'.
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    debug : bool
        flag for lots of message printing for debug purposes
    timeout : int/float
        time (in seconds) to wait for connection replies from the Hub
        [Default: 10]

    #-------------- session related args ----------------

    config : Config object
        If specified, this will be relayed to the Session for configuration
    username : str
        set username for the session object
    packer : str (import_string) or callable
        Can be either the simple keyword 'json' or 'pickle', or an import_string to a
        function to serialize messages. Must support same input as
        JSON, and output must be bytes.
        You can pass a callable directly as `pack`
    unpacker : str (import_string) or callable
        The inverse of packer. Only necessary if packer is specified as *not* one
        of 'json' or 'pickle'.

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to public ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False otherwise]

    #-------------- exec authentication args ----------------
    # If even localhost is untrusted, you can have some protection against
    # unauthorized execution by signing messages with HMAC digests.
    # Messages are still sent as cleartext, so if someone can snoop your
    # loopback traffic this will not protect your privacy, but will prevent
    # unauthorized execution.

    exec_key : str
        an authentication key or file containing a key
        default: None


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures the state
        is up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """


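    # A typical connection sketch (illustrative; 'me@login.example.com' and the
    # profile name are placeholders, not values from this module):
    #
    #     from IPython.parallel import Client
    #
    #     rc = Client()                      # ipcontroller-client.json from the default profile
    #     rc = Client(profile='mycluster')   # or from a named profile
    #     # tunnel to a non-local controller over ssh:
    #     rc = Client('/path/to/ipcontroller-client.json',
    #                 sshserver='me@login.example.com')
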
    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict,))
    metadata = Instance('collections.defaultdict', (Metadata,))
    history = List()
    debug = Bool(False)

    profile=Unicode()
    def _profile_default(self):
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'


    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()
    _closed = False
    _ignored_control_replies=Int(0)
    _ignored_hub_replies=Int(0)

    def __new__(self, *args, **kw):
        # don't raise on positional args
        return HasTraits.__new__(self, **kw)

    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        if location is not None:
            proto,addr,port = util.split_url(url)
            if addr == '127.0.0.1' and location not in LOCAL_IPS and not sshserver:
-                sshserver = location
-                warnings.warn(
-                    "Controller appears to be listening on localhost, but is not local. "
-                    "IPython will try to use SSH tunnels to %s"%location,
+                warnings.warn("""
+            Controller appears to be listening on localhost, but not on this machine.
+            If this is true, you should specify Client(...,sshserver='you@%s')
+            or instruct your controller to listen on an external IP."""%location,
                    RuntimeWarning)

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                extra_args['keyfile'] = exec_key
            else:
                if isinstance(exec_key, unicode):
                    exec_key = exec_key.encode('ascii')
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                       'unregistration_notification' : self._unregister_engine,
                                       'shutdown_notification' : lambda msg: self.close(),
                                      }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if profile_dir is not None:
            try:
                self._cd = ProfileDir.find_profile_dir(profile_dir)
                return
            except ProfileDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ProfileDir.find_profile_dir_by_name(
                    ipython_dir, profile)
                return
            except ProfileDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (uuids, int_ids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [self._engines[t] for t in targets], list(targets)

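    # _build_targets accepts several spellings for the same engines; a sketch
    # of the mapping (engine ids 0-3 assumed purely for illustration):
    #
    #     client._build_targets('all')          # every registered engine
    #     client._build_targets(None)           # same as 'all'
    #     client._build_targets(3)              # single int
    #     client._build_targets(-1)             # negative ints index self.ids
    #     client._build_targets(slice(0, 4, 2)) # engines 0 and 2
    #     client._build_targets([0, 2])         # explicit list of ints
    #
    # Each call returns (uuids, int_ids); anything else raises TypeError.
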
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
            }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = parent['date']
        if 'started' in header:
            md['started'] = header['started']
        if 'date' in header:
            md['completed'] = header['date']
        return md

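    # The timestamps gathered above make simple task profiling possible; a
    # sketch (assumes `md` is the metadata dict of a completed task):
    #
    #     queue_delay = md['started'] - md['submitted']    # time spent waiting
    #     run_time    = md['completed'] - md['started']    # time spent executing
    #     round_trip  = md['received'] - md['submitted']   # as seen by the client
    #
    # All three are datetime.timedelta objects, since the fields are datetimes.
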
    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid : content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        Execute messages are never actually used; apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
        else:
            return self.direct_view(key)

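    # Indexing a Client is the usual shorthand for building DirectViews; an
    # illustrative sketch (assumes a running cluster with four engines):
    #
    #     rc = Client()
    #     e0   = rc[0]       # DirectView on engine 0
    #     even = rc[::2]     # DirectView on engines 0 and 2
    #     all  = rc[:]       # DirectView on all engines
    #     e0.apply_sync(lambda: 'hello')
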
    #--------------------------------------------------------------------------
    # Begin public methods
    #--------------------------------------------------------------------------

    @property
    def ids(self):
        """Always up-to-date ids property."""
        self._flush_notifications()
        # always copy:
        return list(self._ids)

    def close(self):
        if self._closed:
            return
        snames = filter(lambda n: n.endswith('socket'), dir(self))
        for socket in map(lambda name: getattr(self, name), snames):
            if isinstance(socket, zmq.Socket) and not socket.closed:
                socket.close()
        self._closed = True

    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()

    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0

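    # A sketch of the polling pattern wait() supports (names are placeholders;
    # `view` is any View obtained from this client):
    #
    #     ar = view.apply_async(lambda x: x*2, 21)
    #     if rc.wait([ar], timeout=5):    # also accepts msg_ids or history indices
    #         print ar.get()              # everything finished within 5s
    #     else:
    #         print 'still outstanding:', rc.outstanding
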
    #--------------------------------------------------------------------------
    # Control methods
    #--------------------------------------------------------------------------

    @spin_first
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error


    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted


        """
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        block = self.block if block is None else block
        if hub:
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error

    #--------------------------------------------------------------------------
    # Execution related methods
    #--------------------------------------------------------------------------

    def _maybe_raise(self, result):
        """wrapper for maybe raising an exception if apply failed."""
        if isinstance(result, error.RemoteError):
            raise result

        return result

    def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.
        """

        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s"%type(subheader))

        bufs = util.pack_apply_message(f,args,kwargs)

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            subheader=subheader, track=track)

        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg

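    # Views funnel all execution through send_apply_message; a rough sketch of
    # what a DirectView does under the hood (simplified and illustrative, not
    # the actual View implementation):
    #
    #     def apply_async_sketch(client, f, *args, **kwargs):
    #         ident = client._engines[0]   # route to engine 0 via the mux socket
    #         msg = client.send_apply_message(client._mux_socket, f,
    #                                         args=args, kwargs=kwargs, ident=ident)
    #         return AsyncResult(client, msg_ids=[msg['msg_id']])
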
    #--------------------------------------------------------------------------
    # construct a View object
    #--------------------------------------------------------------------------

    def load_balanced_view(self, targets=None):
        """construct a LoadBalancedView object.

        If no arguments are specified, create a LoadBalancedView
        using all engines.

        Parameters
        ----------

        targets: list, slice, int, etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)

    def direct_view(self, targets='all'):
        """construct a DirectView object.

        If no targets are specified, create a DirectView
        using all engines.

        Parameters
        ----------

        targets: list, slice, int, etc. [default: use all engines]
            The engines to use for the View
        """
        single = isinstance(targets, int)
        targets = self._build_targets(targets)[1]
        if single:
            targets = targets[0]
        return DirectView(client=self, socket=self._mux_socket, targets=targets)

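    # The two view types embody the two work-distribution styles; a short
    # sketch (a cluster with registered engines is assumed):
    #
    #     dview = rc.direct_view()           # explicit: you pick the engines
    #     dview.scatter('x', range(100))     # partition data across engines
    #
    #     lview = rc.load_balanced_view()    # implicit: the scheduler picks
    #     ar = lview.map_async(lambda n: n**2, range(32))
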
    #--------------------------------------------------------------------------
    # Query methods
    #--------------------------------------------------------------------------

    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

        If the client already has the results, no request to the Hub will be made.

        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.

        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the metadata as well as the
        result itself.

        Examples
        --------
        ::

            In [10]: r = client.get_result(msg_id)   # by msg_id

            In [11]: r = client.get_result(-1)       # most recent entry in client.history

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of the results to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncResult
            A single AsyncResult object will always be returned.

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, str):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar

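    # Because results live in the Hub, one client can fetch work submitted by
    # another. A sketch ('aabbccdd...' stands in for a real msg_id obtained out
    # of band, e.g. from hub_history()):
    #
    #     rc2 = Client(profile='mycluster')
    #     ar = rc2.get_result('aabbccdd...')   # AsyncHubResult, since it's not local
    #     ar.wait()
    #     print ar.get()
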
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of the tasks to be resubmitted

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, str):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        for msg_id in theids:
            self.outstanding.discard(msg_id)
            if msg_id in self.history:
                self.history.remove(msg_id)
            self.results.pop(msg_id, None)
            self.metadata.pop(msg_id, None)
        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        ar = AsyncHubResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar

    @spin_first
    def result_status(self, msg_ids, status_only=True):
        """Check on the status of the result(s) of the apply request with `msg_ids`.

        If status_only is False, then the actual results will be retrieved, else
        only the status of the results will be checked.

        Parameters
        ----------

        msg_ids : list of msg_ids
            if int:
                Passed as index to self.history for convenience.
        status_only : bool (default: True)
            if False:
                Retrieve the actual results of completed tasks.

        Returns
        -------

        results : dict
            There will always be the keys 'pending' and 'completed', which will
            be lists of msg_ids that are incomplete or complete. If `status_only`
            is False, then completed results will be keyed by their `msg_id`.
        """
        if not isinstance(msg_ids, (list,tuple)):
            msg_ids = [msg_ids]

        theids = []
        for msg_id in msg_ids:
            if isinstance(msg_id, int):
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, basestring):
                raise TypeError("msg_ids must be str, not %r"%msg_id)
            theids.append(msg_id)

        completed = []
        local_results = {}

        # comment this block out to temporarily disable local shortcut:
        for msg_id in list(theids):
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids: # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            msg = self.session.send(self._query_socket, "result_request", content=content)
            zmq.select([self._query_socket], [], [])
            idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise self._unwrap_exception(content)
            buffers = msg['buffers']
        else:
            content = dict(completed=[],pending=[])

        content['completed'].extend(completed)

        if status_only:
            return content

        failures = []
        # load cached results into result:
        content.update(local_results)

        # update cache with results:
        for msg_id in sorted(theids):
            if msg_id in content['completed']:
                rec = content[msg_id]
                parent = rec['header']
                header = rec['result_header']
                rcontent = rec['result_content']
                iodict = rec['io']
                if isinstance(rcontent, str):
                    rcontent = self.session.unpack(rcontent)

                md = self.metadata[msg_id]
                md.update(self._extract_metadata(header, parent, rcontent))
                md.update(iodict)

                if rcontent['status'] == 'ok':
                    res,buffers = util.unserialize_object(buffers)
                else:
                    print rcontent
                    res = self._unwrap_exception(rcontent)
                    failures.append(res)

                self.results[msg_id] = res
                content[msg_id] = res

        if len(theids) == 1 and failures:
            raise failures[0]

        error.collect_exceptions(failures, "result_status")
        return content

    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------

        targets : int/str/list of ints/strs
            the engines whose states are to be queried.
            default : all
        verbose : bool
            Whether to return lengths only, or lists of ids for each element
        """
        engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        content = rekey(content)
        if isinstance(targets, int):
            return content[targets]
        else:
            return content

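    # After rekey(), the returned mapping is keyed by integer engine id. The
    # exact per-engine keys come from the Hub; the shape below is an assumed
    # illustration, not a guaranteed schema:
    #
    #     rc.queue_status()
    #     # {0: {'queue': 2, 'completed': 10, 'tasks': 1},
    #     #  1: {'queue': 0, 'completed': 12, 'tasks': 3}}
    #     rc.queue_status(targets=0, verbose=True)   # lists of msg_ids, not counts
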
    @spin_first
    def purge_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
            The targets, by uuid or int_id, whose entire history is to be purged.
            Use `targets='all'` to scrub everything from the Hub's memory.

            default : []
        """
        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)

        content = dict(targets=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

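    # Typical cleanup patterns (sketch; `ar` is any finished AsyncResult):
    #
    #     rc.purge_results(jobs=ar)                 # forget one batch of results
    #     rc.purge_results(jobs=rc.history[:100])   # forget the oldest 100 msg_ids
    #     rc.purge_results(targets=[0, 1])          # everything run on engines 0-1
    #     rc.purge_results(targets='all')           # scrub the Hub's whole cache
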
    @spin_first
    def hub_history(self):
        """Get the Hub's history

        Just like the Client, the Hub has a history, which is a list of msg_ids.
        This will contain the history of all clients, and, depending on configuration,
        may contain history across multiple cluster sessions.

        Any msg_id returned here is a valid argument to `get_result`.

        Returns
        -------

        msg_ids : list of strs
            list of all msg_ids, ordered by task submission time.
        """

        self.session.send(self._query_socket, "history_request", content={})
        idents, msg = self.session.recv(self._query_socket, 0)

        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        else:
            return content['history']

    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.
        """
        if isinstance(keys, basestring):
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records

__all__ = [ 'Client' ]
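
# db_query takes mongodb-style query dicts; a sketch of common queries.
# Operator support depends on the Hub's db backend, so treat these as
# illustrative rather than guaranteed:
#
#     import datetime
#     yesterday = datetime.datetime.now() - datetime.timedelta(1)
#
#     finished = rc.db_query({'completed': {'$ne': None}},
#                            keys=['msg_id', 'completed'])
#     one_task = rc.db_query({'msg_id': ar.msg_ids[0]})
#     recent   = rc.db_query({'submitted': {'$gt': yesterday}},
#                            keys=['msg_id', 'engine_uuid'])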