##// END OF EJS Templates
allow rc.direct_view('all') to be lazily-evaluated...
MinRK -
Show More
@@ -1,1431 +1,1435 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
# Python 3 compatibility shim: this module's isinstance() checks reference
# xrange, which only exists in Python 2
if sys.version_info[0] >= 3:
    # xrange is used in a couple 'isinstance' tests in py2
    # should be just 'range' in 3k
    xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Decorator for Client methods: flushing pending ZMQ traffic first means
    the decorated method sees up-to-date registration/result state.
    """
    self.spin()
    return f(self, *args, **kwargs)
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed schema: every Metadata instance is pre-populated with
        # exactly these keys; streams default to '' so += works immediately
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE: dict.update does not route through our __setitem__, so the
        # constructor itself is not strict about unknown keys
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` replaces `key in self.iterkeys()`: identical
        # semantics, but O(1) instead of O(n) and portable to Python 3
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
123 123
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 166 # and set sshserver as the same machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to public ssh key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False else]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use semi-private _ids attributes.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
244 244
    # default execution behavior: False means calls return AsyncResults
    block = Bool(False)
    # msg_ids submitted whose results have not yet been received
    outstanding = Set()
    # msg_id -> result, for everything we've received
    results = Instance('collections.defaultdict', (dict,))
    # msg_id -> Metadata record for each request
    metadata = Instance('collections.defaultdict', (Metadata,))
    # all msg_ids ever submitted, in submission order
    history = List()
    # verbose message printing for debugging
    debug = Bool(False)

    # profile name used to locate connection information
    profile=Unicode()
    def _profile_default(self):
        """Dynamic default for `profile`: reuse the running IPython app's
        profile when one exists, otherwise u'default'."""
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'
264 264
265 265
266 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 267 _ids = List()
268 268 _connected=Bool(False)
269 269 _ssh=Bool(False)
270 270 _context = Instance('zmq.Context')
271 271 _config = Dict()
272 272 _engines=Instance(util.ReverseDict, (), {})
273 273 # _hub_socket=Instance('zmq.Socket')
274 274 _query_socket=Instance('zmq.Socket')
275 275 _control_socket=Instance('zmq.Socket')
276 276 _iopub_socket=Instance('zmq.Socket')
277 277 _notification_socket=Instance('zmq.Socket')
278 278 _mux_socket=Instance('zmq.Socket')
279 279 _task_socket=Instance('zmq.Socket')
280 280 _task_scheme=Unicode()
281 281 _closed = False
282 282 _ignored_control_replies=Int(0)
283 283 _ignored_hub_replies=Int(0)
284 284
    def __new__(self, *args, **kw):
        # don't raise on positional args
        # (HasTraits.__new__ rejects positionals; Client.__init__ accepts
        # them, so swallow *args here and forward only keywords)
        return HasTraits.__new__(self, **kw)
288 288
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Resolve connection info, build the Session, optionally tunnel
        over SSH, and register with the Hub.

        See the class docstring for parameter descriptions.  Raises
        AssertionError when no connection info can be located, and
        propagates a TimeoutError/Exception from _connect on failure.
        """
        # only pass profile up when explicitly given, so the dynamic
        # _profile_default can take effect otherwise
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        # url_or_file may be either a zmq url or a path to a JSON
        # connection file; validate_url raises AssertionError for non-urls
        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                # bare filename: resolve relative to the profile's security dir
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit keyword args always win over the connection file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                # False (not None) marks "no password needed" for tunnel calls
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # a path: let the Session read the key from the file
                extra_args['keyfile'] = exec_key
            else:
                # the key itself, as bytes
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        # identify ourselves to the Hub by our session id
        self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables used by the _flush_* machinery
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
388 388
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the Context may be shared (zmq.Context.instance()), so only our
        # own sockets are closed here
        self.close()
392 392
393 393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 394 if ipython_dir is None:
395 395 ipython_dir = get_ipython_dir()
396 396 if profile_dir is not None:
397 397 try:
398 398 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 399 return
400 400 except ProfileDirError:
401 401 pass
402 402 elif profile is not None:
403 403 try:
404 404 self._cd = ProfileDir.find_profile_dir_by_name(
405 405 ipython_dir, profile)
406 406 return
407 407 except ProfileDirError:
408 408 pass
409 409 self._cd = None
410 410
411 411 def _update_engines(self, engines):
412 412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 413 for k,v in engines.iteritems():
414 414 eid = int(k)
415 415 self._engines[eid] = v
416 416 self._ids.append(eid)
417 417 self._ids = sorted(self._ids)
418 418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 419 self._task_scheme == 'pure' and self._task_socket:
420 420 self._stop_scheduling_tasks()
421 421
422 422 def _stop_scheduling_tasks(self):
423 423 """Stop scheduling tasks because an engine has been unregistered
424 424 from a pure ZMQ scheduler.
425 425 """
426 426 self._task_socket.close()
427 427 self._task_socket = None
428 428 msg = "An engine has been unregistered, and we are using pure " +\
429 429 "ZMQ task scheduling. Task farming will be disabled."
430 430 if self.outstanding:
431 431 msg += " If you were running tasks when this happened, " +\
432 432 "some `outstanding` msg_ids may never resolve."
433 433 warnings.warn(msg, RuntimeWarning)
434 434
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            # default: all registered engines
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative ints index from the end, like list indexing
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # slices select by position within the sorted id list
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        # sockets address engines by uuid bytes; users see int ids
        return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467 467
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits
        up to `timeout` seconds for the reply, then connects one socket per
        channel (mux/task/notification/control/iopub) that the Hub reports.
        Raises TimeoutError on no reply, Exception on a non-ok status.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve 0.0.0.0/localhost relative to the Hub's location,
            # and tunnel if SSH is in use
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # every channel socket identifies itself by our session id
            ident = util.asbytes(self.session.session)
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # task channel also reports the scheduler scheme in use
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # subscribe to everything the Hub publishes
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
530 530
531 531 #--------------------------------------------------------------------------
532 532 # handlers and callbacks for incoming messages
533 533 #--------------------------------------------------------------------------
534 534
    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            # the wire format carries the engine's uuid; translate it to
            # the integer id users know the engine by
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e
544 544
545 545 def _extract_metadata(self, header, parent, content):
546 546 md = {'msg_id' : parent['msg_id'],
547 547 'received' : datetime.now(),
548 548 'engine_uuid' : header.get('engine', None),
549 549 'follow' : parent.get('follow', []),
550 550 'after' : parent.get('after', []),
551 551 'status' : content['status'],
552 552 }
553 553
554 554 if md['engine_uuid'] is not None:
555 555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556 556
557 557 if 'date' in parent:
558 558 md['submitted'] = parent['date']
559 559 if 'started' in header:
560 560 md['started'] = header['started']
561 561 if 'date' in header:
562 562 md['completed'] = header['date']
563 563 return md
564 564
565 565 def _register_engine(self, msg):
566 566 """Register a new engine, and update our connection info."""
567 567 content = msg['content']
568 568 eid = content['id']
569 569 d = {eid : content['queue']}
570 570 self._update_engines(d)
571 571
572 572 def _unregister_engine(self, msg):
573 573 """Unregister an engine that has died."""
574 574 content = msg['content']
575 575 eid = int(content['id'])
576 576 if eid in self._ids:
577 577 self._ids.remove(eid)
578 578 uuid = self._engines.pop(eid)
579 579
580 580 self._handle_stranded_msgs(eid, uuid)
581 581
582 582 if self._task_socket and self._task_scheme == 'pure':
583 583 self._stop_scheduling_tasks()
584 584
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                # bare except is deliberate: wrap_exception() formats the
                # exception currently being handled (raise-to-capture idiom)
                content = error.wrap_exception()
            # build a fake message:
            # synthesize an apply_reply so the normal reply path records
            # the EngineError as this task's result
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)
611 611
612 612 def _handle_execute_reply(self, msg):
613 613 """Save the reply to an execute_request into our results.
614 614
615 615 execute messages are never actually used. apply is used instead.
616 616 """
617 617
618 618 parent = msg['parent_header']
619 619 msg_id = parent['msg_id']
620 620 if msg_id not in self.outstanding:
621 621 if msg_id in self.history:
622 622 print ("got stale result: %s"%msg_id)
623 623 else:
624 624 print ("got unknown result: %s"%msg_id)
625 625 else:
626 626 self.outstanding.remove(msg_id)
627 627 self.results[msg_id] = self._unwrap_exception(msg['content'])
628 628
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results.

        Updates outstanding sets and per-engine bookkeeping, merges reply
        metadata, and stores the deserialized result (or wrapped error).
        """
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # the engine is done with this msg; drop it from its pending set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # error status: store the remote exception as the result
            self.results[msg_id] = self._unwrap_exception(content)
665 665
666 666 def _flush_notifications(self):
667 667 """Flush notifications of engine registrations waiting
668 668 in ZMQ queue."""
669 669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 670 while msg is not None:
671 671 if self.debug:
672 672 pprint(msg)
673 673 msg_type = msg['msg_type']
674 674 handler = self._notification_handlers.get(msg_type, None)
675 675 if handler is None:
676 676 raise Exception("Unhandled message type: %s"%msg.msg_type)
677 677 else:
678 678 handler(msg)
679 679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680 680
681 681 def _flush_results(self, sock):
682 682 """Flush task or queue results waiting in ZMQ queue."""
683 683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 684 while msg is not None:
685 685 if self.debug:
686 686 pprint(msg)
687 687 msg_type = msg['msg_type']
688 688 handler = self._queue_handlers.get(msg_type, None)
689 689 if handler is None:
690 690 raise Exception("Unhandled message type: %s"%msg.msg_type)
691 691 else:
692 692 handler(msg)
693 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 694
695 695 def _flush_control(self, sock):
696 696 """Flush replies from the control channel waiting
697 697 in the ZMQ queue.
698 698
699 699 Currently: ignore them."""
700 700 if self._ignored_control_replies <= 0:
701 701 return
702 702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 703 while msg is not None:
704 704 self._ignored_control_replies -= 1
705 705 if self.debug:
706 706 pprint(msg)
707 707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708 708
709 709 def _flush_ignored_control(self):
710 710 """flush ignored control replies"""
711 711 while self._ignored_control_replies > 0:
712 712 self.session.recv(self._control_socket)
713 713 self._ignored_control_replies -= 1
714 714
715 715 def _flush_ignored_hub_replies(self):
716 716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 717 while msg is not None:
718 718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719 719
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.

        Folds stdout/stderr/pyin/pyerr messages into the per-msg_id
        Metadata records.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append stream data: name is 'stdout' or 'stderr'
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                # e.g. pyout: store the payload under the message type key
                md.update({msg_type : content.get('data', '')})

            # reduntant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
752 752
753 753 #--------------------------------------------------------------------------
754 754 # len, getitem
755 755 #--------------------------------------------------------------------------
756 756
757 757 def __len__(self):
758 758 """len(client) returns # of engines."""
759 759 return len(self.ids)
760 760
761 761 def __getitem__(self, key):
762 762 """index access returns DirectView multiplexer objects
763 763
764 764 Must be int, slice, or list/tuple/xrange of ints"""
765 765 if not isinstance(key, (int, slice, tuple, list, xrange)):
766 766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
767 767 else:
768 768 return self.direct_view(key)
769 769
770 770 #--------------------------------------------------------------------------
771 771 # Begin public methods
772 772 #--------------------------------------------------------------------------
773 773
774 774 @property
775 775 def ids(self):
776 776 """Always up-to-date ids property."""
777 777 self._flush_notifications()
778 778 # always copy:
779 779 return list(self._ids)
780 780
781 781 def close(self):
782 782 if self._closed:
783 783 return
784 784 snames = filter(lambda n: n.endswith('socket'), dir(self))
785 785 for socket in map(lambda name: getattr(self, name), snames):
786 786 if isinstance(socket, zmq.Socket) and not socket.closed:
787 787 socket.close()
788 788 self._closed = True
789 789
790 790 def spin(self):
791 791 """Flush any registration notifications and execution results
792 792 waiting in the ZMQ queue.
793 793 """
794 794 if self._notification_socket:
795 795 self._flush_notifications()
796 796 if self._mux_socket:
797 797 self._flush_results(self._mux_socket)
798 798 if self._task_socket:
799 799 self._flush_results(self._task_socket)
800 800 if self._control_socket:
801 801 self._flush_control(self._control_socket)
802 802 if self._iopub_socket:
803 803 self._flush_iopub(self._iopub_socket)
804 804 if self._query_socket:
805 805 self._flush_ignored_hub_replies()
806 806
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, basestring, AsyncResult)):
                jobs = [jobs]
            # normalize every job spec down to a set of msg_ids
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            # nothing to wait for
            return True
        self.spin()
        # poll at ~1ms granularity until done or timed out
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
851 851
852 852 #--------------------------------------------------------------------------
853 853 # Control methods
854 854 #--------------------------------------------------------------------------
855 855
856 856 @spin_first
857 857 def clear(self, targets=None, block=None):
858 858 """Clear the namespace in target(s)."""
859 859 block = self.block if block is None else block
860 860 targets = self._build_targets(targets)[0]
861 861 for t in targets:
862 862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
863 863 error = False
864 864 if block:
865 865 self._flush_ignored_control()
866 866 for i in range(len(targets)):
867 867 idents,msg = self.session.recv(self._control_socket,0)
868 868 if self.debug:
869 869 pprint(msg)
870 870 if msg['content']['status'] != 'ok':
871 871 error = self._unwrap_exception(msg['content'])
872 872 else:
873 873 self._ignored_control_replies += len(targets)
874 874 if error:
875 875 raise error
876 876
877 877
878 878 @spin_first
879 879 def abort(self, jobs=None, targets=None, block=None):
880 880 """Abort specific jobs from the execution queues of target(s).
881 881
882 882 This is a mechanism to prevent jobs that have already been submitted
883 883 from executing.
884 884
885 885 Parameters
886 886 ----------
887 887
888 888 jobs : msg_id, list of msg_ids, or AsyncResult
889 889 The jobs to be aborted
890 890
891 891
892 892 """
893 893 block = self.block if block is None else block
894 894 targets = self._build_targets(targets)[0]
895 895 msg_ids = []
896 896 if isinstance(jobs, (basestring,AsyncResult)):
897 897 jobs = [jobs]
898 898 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
899 899 if bad_ids:
900 900 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
901 901 for j in jobs:
902 902 if isinstance(j, AsyncResult):
903 903 msg_ids.extend(j.msg_ids)
904 904 else:
905 905 msg_ids.append(j)
906 906 content = dict(msg_ids=msg_ids)
907 907 for t in targets:
908 908 self.session.send(self._control_socket, 'abort_request',
909 909 content=content, ident=t)
910 910 error = False
911 911 if block:
912 912 self._flush_ignored_control()
913 913 for i in range(len(targets)):
914 914 idents,msg = self.session.recv(self._control_socket,0)
915 915 if self.debug:
916 916 pprint(msg)
917 917 if msg['content']['status'] != 'ok':
918 918 error = self._unwrap_exception(msg['content'])
919 919 else:
920 920 self._ignored_control_replies += len(targets)
921 921 if error:
922 922 raise error
923 923
924 924 @spin_first
925 925 def shutdown(self, targets=None, restart=False, hub=False, block=None):
926 926 """Terminates one or more engine processes, optionally including the hub."""
927 927 block = self.block if block is None else block
928 928 if hub:
929 929 targets = 'all'
930 930 targets = self._build_targets(targets)[0]
931 931 for t in targets:
932 932 self.session.send(self._control_socket, 'shutdown_request',
933 933 content={'restart':restart},ident=t)
934 934 error = False
935 935 if block or hub:
936 936 self._flush_ignored_control()
937 937 for i in range(len(targets)):
938 938 idents,msg = self.session.recv(self._control_socket, 0)
939 939 if self.debug:
940 940 pprint(msg)
941 941 if msg['content']['status'] != 'ok':
942 942 error = self._unwrap_exception(msg['content'])
943 943 else:
944 944 self._ignored_control_replies += len(targets)
945 945
946 946 if hub:
947 947 time.sleep(0.25)
948 948 self.session.send(self._query_socket, 'shutdown_request')
949 949 idents,msg = self.session.recv(self._query_socket, 0)
950 950 if self.debug:
951 951 pprint(msg)
952 952 if msg['content']['status'] != 'ok':
953 953 error = self._unwrap_exception(msg['content'])
954 954
955 955 if error:
956 956 raise error
957 957
958 958 #--------------------------------------------------------------------------
959 959 # Execution related methods
960 960 #--------------------------------------------------------------------------
961 961
962 962 def _maybe_raise(self, result):
963 963 """wrapper for maybe raising an exception if apply failed."""
964 964 if isinstance(result, error.RemoteError):
965 965 raise result
966 966
967 967 return result
968 968
969 969 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
970 970 ident=None):
971 971 """construct and send an apply message via a socket.
972 972
973 973 This is the principal method with which all engine execution is performed by views.
974 974 """
975 975
976 976 assert not self._closed, "cannot use me anymore, I'm closed!"
977 977 # defaults:
978 978 args = args if args is not None else []
979 979 kwargs = kwargs if kwargs is not None else {}
980 980 subheader = subheader if subheader is not None else {}
981 981
982 982 # validate arguments
983 983 if not callable(f):
984 984 raise TypeError("f must be callable, not %s"%type(f))
985 985 if not isinstance(args, (tuple, list)):
986 986 raise TypeError("args must be tuple or list, not %s"%type(args))
987 987 if not isinstance(kwargs, dict):
988 988 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
989 989 if not isinstance(subheader, dict):
990 990 raise TypeError("subheader must be dict, not %s"%type(subheader))
991 991
992 992 bufs = util.pack_apply_message(f,args,kwargs)
993 993
994 994 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
995 995 subheader=subheader, track=track)
996 996
997 997 msg_id = msg['msg_id']
998 998 self.outstanding.add(msg_id)
999 999 if ident:
1000 1000 # possibly routed to a specific engine
1001 1001 if isinstance(ident, list):
1002 1002 ident = ident[-1]
1003 1003 if ident in self._engines.values():
1004 1004 # save for later, in case of engine death
1005 1005 self._outstanding_dict[ident].add(msg_id)
1006 1006 self.history.append(msg_id)
1007 1007 self.metadata[msg_id]['submitted'] = datetime.now()
1008 1008
1009 1009 return msg
1010 1010
1011 1011 #--------------------------------------------------------------------------
1012 1012 # construct a View object
1013 1013 #--------------------------------------------------------------------------
1014 1014
1015 1015 def load_balanced_view(self, targets=None):
1016 1016         """construct a LoadBalancedView object.
1017 1017
1018 1018 If no arguments are specified, create a LoadBalancedView
1019 1019 using all engines.
1020 1020
1021 1021 Parameters
1022 1022 ----------
1023 1023
1024 1024 targets: list,slice,int,etc. [default: use all engines]
1025 1025 The subset of engines across which to load-balance
1026 1026 """
1027 if targets == 'all':
1028 targets = None
1027 1029 if targets is not None:
1028 1030 targets = self._build_targets(targets)[1]
1029 1031 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1030 1032
1031 1033 def direct_view(self, targets='all'):
1032 1034 """construct a DirectView object.
1033 1035
1034 1036 If no targets are specified, create a DirectView
1035 1037 using all engines.
1036 1038
1037 1039 Parameters
1038 1040 ----------
1039 1041
1040 1042 targets: list,slice,int,etc. [default: use all engines]
1041 1043 The engines to use for the View
1042 1044 """
1043 1045 single = isinstance(targets, int)
1044 targets = self._build_targets(targets)[1]
1046 # allow 'all' to be lazily evaluated at each execution
1047 if targets != 'all':
1048 targets = self._build_targets(targets)[1]
1045 1049 if single:
1046 1050 targets = targets[0]
1047 1051 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1048 1052
1049 1053 #--------------------------------------------------------------------------
1050 1054 # Query methods
1051 1055 #--------------------------------------------------------------------------
1052 1056
1053 1057 @spin_first
1054 1058 def get_result(self, indices_or_msg_ids=None, block=None):
1055 1059 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1056 1060
1057 1061 If the client already has the results, no request to the Hub will be made.
1058 1062
1059 1063 This is a convenient way to construct AsyncResult objects, which are wrappers
1060 1064 that include metadata about execution, and allow for awaiting results that
1061 1065 were not submitted by this Client.
1062 1066
1063 1067 It can also be a convenient way to retrieve the metadata associated with
1064 1068         blocking execution, since it always retrieves the metadata for the given msg_ids.
1065 1069
1066 1070 Examples
1067 1071 --------
1068 1072 ::
1069 1073
1070 1074 In [10]: r = client.apply()
1071 1075
1072 1076 Parameters
1073 1077 ----------
1074 1078
1075 1079 indices_or_msg_ids : integer history index, str msg_id, or list of either
1076 1080         indices_or_msg_ids : integer history index, str msg_id, or list of either
1077 1081
1078 1082 block : bool
1079 1083 Whether to wait for the result to be done
1080 1084
1081 1085 Returns
1082 1086 -------
1083 1087
1084 1088 AsyncResult
1085 1089 A single AsyncResult object will always be returned.
1086 1090
1087 1091 AsyncHubResult
1088 1092 A subclass of AsyncResult that retrieves results from the Hub
1089 1093
1090 1094 """
1091 1095 block = self.block if block is None else block
1092 1096 if indices_or_msg_ids is None:
1093 1097 indices_or_msg_ids = -1
1094 1098
1095 1099 if not isinstance(indices_or_msg_ids, (list,tuple)):
1096 1100 indices_or_msg_ids = [indices_or_msg_ids]
1097 1101
1098 1102 theids = []
1099 1103 for id in indices_or_msg_ids:
1100 1104 if isinstance(id, int):
1101 1105 id = self.history[id]
1102 1106 if not isinstance(id, basestring):
1103 1107 raise TypeError("indices must be str or int, not %r"%id)
1104 1108 theids.append(id)
1105 1109
1106 1110 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1107 1111 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1108 1112
1109 1113 if remote_ids:
1110 1114 ar = AsyncHubResult(self, msg_ids=theids)
1111 1115 else:
1112 1116 ar = AsyncResult(self, msg_ids=theids)
1113 1117
1114 1118 if block:
1115 1119 ar.wait()
1116 1120
1117 1121 return ar
1118 1122
1119 1123 @spin_first
1120 1124 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1121 1125 """Resubmit one or more tasks.
1122 1126
1123 1127 in-flight tasks may not be resubmitted.
1124 1128
1125 1129 Parameters
1126 1130 ----------
1127 1131
1128 1132 indices_or_msg_ids : integer history index, str msg_id, or list of either
1129 1133         indices_or_msg_ids : integer history index, str msg_id, or list of either
1130 1134             The indices or msg_ids of the jobs to be resubmitted
1131 1135 block : bool
1132 1136 Whether to wait for the result to be done
1133 1137
1134 1138 Returns
1135 1139 -------
1136 1140
1137 1141 AsyncHubResult
1138 1142 A subclass of AsyncResult that retrieves results from the Hub
1139 1143
1140 1144 """
1141 1145 block = self.block if block is None else block
1142 1146 if indices_or_msg_ids is None:
1143 1147 indices_or_msg_ids = -1
1144 1148
1145 1149 if not isinstance(indices_or_msg_ids, (list,tuple)):
1146 1150 indices_or_msg_ids = [indices_or_msg_ids]
1147 1151
1148 1152 theids = []
1149 1153 for id in indices_or_msg_ids:
1150 1154 if isinstance(id, int):
1151 1155 id = self.history[id]
1152 1156 if not isinstance(id, basestring):
1153 1157 raise TypeError("indices must be str or int, not %r"%id)
1154 1158 theids.append(id)
1155 1159
1156 1160 for msg_id in theids:
1157 1161 self.outstanding.discard(msg_id)
1158 1162 if msg_id in self.history:
1159 1163 self.history.remove(msg_id)
1160 1164 self.results.pop(msg_id, None)
1161 1165 self.metadata.pop(msg_id, None)
1162 1166 content = dict(msg_ids = theids)
1163 1167
1164 1168 self.session.send(self._query_socket, 'resubmit_request', content)
1165 1169
1166 1170 zmq.select([self._query_socket], [], [])
1167 1171 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1168 1172 if self.debug:
1169 1173 pprint(msg)
1170 1174 content = msg['content']
1171 1175 if content['status'] != 'ok':
1172 1176 raise self._unwrap_exception(content)
1173 1177
1174 1178 ar = AsyncHubResult(self, msg_ids=theids)
1175 1179
1176 1180 if block:
1177 1181 ar.wait()
1178 1182
1179 1183 return ar
1180 1184
1181 1185 @spin_first
1182 1186 def result_status(self, msg_ids, status_only=True):
1183 1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1184 1188
1185 1189 If status_only is False, then the actual results will be retrieved, else
1186 1190 only the status of the results will be checked.
1187 1191
1188 1192 Parameters
1189 1193 ----------
1190 1194
1191 1195 msg_ids : list of msg_ids
1192 1196 if int:
1193 1197 Passed as index to self.history for convenience.
1194 1198 status_only : bool (default: True)
1195 1199 if False:
1196 1200 Retrieve the actual results of completed tasks.
1197 1201
1198 1202 Returns
1199 1203 -------
1200 1204
1201 1205 results : dict
1202 1206 There will always be the keys 'pending' and 'completed', which will
1203 1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1204 1208 is False, then completed results will be keyed by their `msg_id`.
1205 1209 """
1206 1210 if not isinstance(msg_ids, (list,tuple)):
1207 1211 msg_ids = [msg_ids]
1208 1212
1209 1213 theids = []
1210 1214 for msg_id in msg_ids:
1211 1215 if isinstance(msg_id, int):
1212 1216 msg_id = self.history[msg_id]
1213 1217 if not isinstance(msg_id, basestring):
1214 1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1215 1219 theids.append(msg_id)
1216 1220
1217 1221 completed = []
1218 1222 local_results = {}
1219 1223
1220 1224 # comment this block out to temporarily disable local shortcut:
1221 1225 for msg_id in theids:
1222 1226 if msg_id in self.results:
1223 1227 completed.append(msg_id)
1224 1228 local_results[msg_id] = self.results[msg_id]
1225 1229 theids.remove(msg_id)
1226 1230
1227 1231 if theids: # some not locally cached
1228 1232 content = dict(msg_ids=theids, status_only=status_only)
1229 1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1230 1234 zmq.select([self._query_socket], [], [])
1231 1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1232 1236 if self.debug:
1233 1237 pprint(msg)
1234 1238 content = msg['content']
1235 1239 if content['status'] != 'ok':
1236 1240 raise self._unwrap_exception(content)
1237 1241 buffers = msg['buffers']
1238 1242 else:
1239 1243 content = dict(completed=[],pending=[])
1240 1244
1241 1245 content['completed'].extend(completed)
1242 1246
1243 1247 if status_only:
1244 1248 return content
1245 1249
1246 1250 failures = []
1247 1251 # load cached results into result:
1248 1252 content.update(local_results)
1249 1253
1250 1254 # update cache with results:
1251 1255 for msg_id in sorted(theids):
1252 1256 if msg_id in content['completed']:
1253 1257 rec = content[msg_id]
1254 1258 parent = rec['header']
1255 1259 header = rec['result_header']
1256 1260 rcontent = rec['result_content']
1257 1261 iodict = rec['io']
1258 1262 if isinstance(rcontent, str):
1259 1263 rcontent = self.session.unpack(rcontent)
1260 1264
1261 1265 md = self.metadata[msg_id]
1262 1266 md.update(self._extract_metadata(header, parent, rcontent))
1263 1267 md.update(iodict)
1264 1268
1265 1269 if rcontent['status'] == 'ok':
1266 1270 res,buffers = util.unserialize_object(buffers)
1267 1271 else:
1268 1272 print rcontent
1269 1273 res = self._unwrap_exception(rcontent)
1270 1274 failures.append(res)
1271 1275
1272 1276 self.results[msg_id] = res
1273 1277 content[msg_id] = res
1274 1278
1275 1279 if len(theids) == 1 and failures:
1276 1280 raise failures[0]
1277 1281
1278 1282 error.collect_exceptions(failures, "result_status")
1279 1283 return content
1280 1284
1281 1285 @spin_first
1282 1286 def queue_status(self, targets='all', verbose=False):
1283 1287 """Fetch the status of engine queues.
1284 1288
1285 1289 Parameters
1286 1290 ----------
1287 1291
1288 1292 targets : int/str/list of ints/strs
1289 1293 the engines whose states are to be queried.
1290 1294 default : all
1291 1295 verbose : bool
1292 1296 Whether to return lengths only, or lists of ids for each element
1293 1297 """
1294 1298 engine_ids = self._build_targets(targets)[1]
1295 1299 content = dict(targets=engine_ids, verbose=verbose)
1296 1300 self.session.send(self._query_socket, "queue_request", content=content)
1297 1301 idents,msg = self.session.recv(self._query_socket, 0)
1298 1302 if self.debug:
1299 1303 pprint(msg)
1300 1304 content = msg['content']
1301 1305 status = content.pop('status')
1302 1306 if status != 'ok':
1303 1307 raise self._unwrap_exception(content)
1304 1308 content = rekey(content)
1305 1309 if isinstance(targets, int):
1306 1310 return content[targets]
1307 1311 else:
1308 1312 return content
1309 1313
1310 1314 @spin_first
1311 1315 def purge_results(self, jobs=[], targets=[]):
1312 1316 """Tell the Hub to forget results.
1313 1317
1314 1318 Individual results can be purged by msg_id, or the entire
1315 1319 history of specific targets can be purged.
1316 1320
1317 1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1318 1322
1319 1323 Parameters
1320 1324 ----------
1321 1325
1322 1326 jobs : str or list of str or AsyncResult objects
1323 1327 the msg_ids whose results should be forgotten.
1324 1328 targets : int/str/list of ints/strs
1325 1329 The targets, by int_id, whose entire history is to be purged.
1326 1330
1327 1331 default : None
1328 1332 """
1329 1333 if not targets and not jobs:
1330 1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1331 1335 if targets:
1332 1336 targets = self._build_targets(targets)[1]
1333 1337
1334 1338 # construct msg_ids from jobs
1335 1339 if jobs == 'all':
1336 1340 msg_ids = jobs
1337 1341 else:
1338 1342 msg_ids = []
1339 1343 if isinstance(jobs, (basestring,AsyncResult)):
1340 1344 jobs = [jobs]
1341 1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1342 1346 if bad_ids:
1343 1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1344 1348 for j in jobs:
1345 1349 if isinstance(j, AsyncResult):
1346 1350 msg_ids.extend(j.msg_ids)
1347 1351 else:
1348 1352 msg_ids.append(j)
1349 1353
1350 1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1351 1355 self.session.send(self._query_socket, "purge_request", content=content)
1352 1356 idents, msg = self.session.recv(self._query_socket, 0)
1353 1357 if self.debug:
1354 1358 pprint(msg)
1355 1359 content = msg['content']
1356 1360 if content['status'] != 'ok':
1357 1361 raise self._unwrap_exception(content)
1358 1362
1359 1363 @spin_first
1360 1364 def hub_history(self):
1361 1365 """Get the Hub's history
1362 1366
1363 1367 Just like the Client, the Hub has a history, which is a list of msg_ids.
1364 1368 This will contain the history of all clients, and, depending on configuration,
1365 1369 may contain history across multiple cluster sessions.
1366 1370
1367 1371 Any msg_id returned here is a valid argument to `get_result`.
1368 1372
1369 1373 Returns
1370 1374 -------
1371 1375
1372 1376 msg_ids : list of strs
1373 1377 list of all msg_ids, ordered by task submission time.
1374 1378 """
1375 1379
1376 1380 self.session.send(self._query_socket, "history_request", content={})
1377 1381 idents, msg = self.session.recv(self._query_socket, 0)
1378 1382
1379 1383 if self.debug:
1380 1384 pprint(msg)
1381 1385 content = msg['content']
1382 1386 if content['status'] != 'ok':
1383 1387 raise self._unwrap_exception(content)
1384 1388 else:
1385 1389 return content['history']
1386 1390
1387 1391 @spin_first
1388 1392 def db_query(self, query, keys=None):
1389 1393 """Query the Hub's TaskRecord database
1390 1394
1391 1395 This will return a list of task record dicts that match `query`
1392 1396
1393 1397 Parameters
1394 1398 ----------
1395 1399
1396 1400 query : mongodb query dict
1397 1401 The search dict. See mongodb query docs for details.
1398 1402 keys : list of strs [optional]
1399 1403 The subset of keys to be returned. The default is to fetch everything but buffers.
1400 1404 'msg_id' will *always* be included.
1401 1405 """
1402 1406 if isinstance(keys, basestring):
1403 1407 keys = [keys]
1404 1408 content = dict(query=query, keys=keys)
1405 1409 self.session.send(self._query_socket, "db_request", content=content)
1406 1410 idents, msg = self.session.recv(self._query_socket, 0)
1407 1411 if self.debug:
1408 1412 pprint(msg)
1409 1413 content = msg['content']
1410 1414 if content['status'] != 'ok':
1411 1415 raise self._unwrap_exception(content)
1412 1416
1413 1417 records = content['records']
1414 1418
1415 1419 buffer_lens = content['buffer_lens']
1416 1420 result_buffer_lens = content['result_buffer_lens']
1417 1421 buffers = msg['buffers']
1418 1422 has_bufs = buffer_lens is not None
1419 1423 has_rbufs = result_buffer_lens is not None
1420 1424 for i,rec in enumerate(records):
1421 1425 # relink buffers
1422 1426 if has_bufs:
1423 1427 blen = buffer_lens[i]
1424 1428 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1425 1429 if has_rbufs:
1426 1430 blen = result_buffer_lens[i]
1427 1431 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1428 1432
1429 1433 return records
1430 1434
1431 1435 __all__ = [ 'Client' ]
@@ -1,270 +1,279 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython.parallel.client import client as clientmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult
30 30 from IPython.parallel import LoadBalancedView, DirectView
31 31
32 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 33
34 34 def setup():
35 35 add_engines(4)
36 36
37 37 class TestClient(ClusterTestCase):
38 38
39 39 def test_ids(self):
40 40 n = len(self.client.ids)
41 41 self.add_engines(3)
42 42 self.assertEquals(len(self.client.ids), n+3)
43 43
44 44 def test_view_indexing(self):
45 45 """test index access for views"""
46 46 self.add_engines(2)
47 47 targets = self.client._build_targets('all')[-1]
48 48 v = self.client[:]
49 49 self.assertEquals(v.targets, targets)
50 50 t = self.client.ids[2]
51 51 v = self.client[t]
52 52 self.assert_(isinstance(v, DirectView))
53 53 self.assertEquals(v.targets, t)
54 54 t = self.client.ids[2:4]
55 55 v = self.client[t]
56 56 self.assert_(isinstance(v, DirectView))
57 57 self.assertEquals(v.targets, t)
58 58 v = self.client[::2]
59 59 self.assert_(isinstance(v, DirectView))
60 60 self.assertEquals(v.targets, targets[::2])
61 61 v = self.client[1::3]
62 62 self.assert_(isinstance(v, DirectView))
63 63 self.assertEquals(v.targets, targets[1::3])
64 64 v = self.client[:-3]
65 65 self.assert_(isinstance(v, DirectView))
66 66 self.assertEquals(v.targets, targets[:-3])
67 67 v = self.client[-1]
68 68 self.assert_(isinstance(v, DirectView))
69 69 self.assertEquals(v.targets, targets[-1])
70 70 self.assertRaises(TypeError, lambda : self.client[None])
71 71
72 72 def test_lbview_targets(self):
73 73 """test load_balanced_view targets"""
74 74 v = self.client.load_balanced_view()
75 75 self.assertEquals(v.targets, None)
76 76 v = self.client.load_balanced_view(-1)
77 77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, self.client.ids)
79 self.assertEquals(v.targets, None)
80
81 def test_dview_targets(self):
82         """test direct_view targets"""
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
80 89
81 90 def test_targets(self):
82 91 """test various valid targets arguments"""
83 92 build = self.client._build_targets
84 93 ids = self.client.ids
85 94 idents,targets = build(None)
86 95 self.assertEquals(ids, targets)
87 96
88 97 def test_clear(self):
89 98 """test clear behavior"""
90 99 # self.add_engines(2)
91 100 v = self.client[:]
92 101 v.block=True
93 102 v.push(dict(a=5))
94 103 v.pull('a')
95 104 id0 = self.client.ids[-1]
96 105 self.client.clear(targets=id0, block=True)
97 106 a = self.client[:-1].get('a')
98 107 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
99 108 self.client.clear(block=True)
100 109 for i in self.client.ids:
101 110 # print i
102 111 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
103 112
104 113 def test_get_result(self):
105 114 """test getting results from the Hub."""
106 115 c = clientmod.Client(profile='iptest')
107 116 # self.add_engines(1)
108 117 t = c.ids[-1]
109 118 ar = c[t].apply_async(wait, 1)
110 119 # give the monitor time to notice the message
111 120 time.sleep(.25)
112 121 ahr = self.client.get_result(ar.msg_ids)
113 122 self.assertTrue(isinstance(ahr, AsyncHubResult))
114 123 self.assertEquals(ahr.get(), ar.get())
115 124 ar2 = self.client.get_result(ar.msg_ids)
116 125 self.assertFalse(isinstance(ar2, AsyncHubResult))
117 126 c.close()
118 127
119 128 def test_ids_list(self):
120 129 """test client.ids"""
121 130 # self.add_engines(2)
122 131 ids = self.client.ids
123 132 self.assertEquals(ids, self.client._ids)
124 133 self.assertFalse(ids is self.client._ids)
125 134 ids.remove(ids[-1])
126 135 self.assertNotEquals(ids, self.client._ids)
127 136
128 137 def test_queue_status(self):
129 138 # self.addEngine(4)
130 139 ids = self.client.ids
131 140 id0 = ids[0]
132 141 qs = self.client.queue_status(targets=id0)
133 142 self.assertTrue(isinstance(qs, dict))
134 143 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135 144 allqs = self.client.queue_status()
136 145 self.assertTrue(isinstance(allqs, dict))
137 146 intkeys = list(allqs.keys())
138 147 intkeys.remove('unassigned')
139 148 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
140 149 unassigned = allqs.pop('unassigned')
141 150 for eid,qs in allqs.items():
142 151 self.assertTrue(isinstance(qs, dict))
143 152 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
144 153
145 154 def test_shutdown(self):
146 155 # self.addEngine(4)
147 156 ids = self.client.ids
148 157 id0 = ids[0]
149 158 self.client.shutdown(id0, block=True)
150 159 while id0 in self.client.ids:
151 160 time.sleep(0.1)
152 161 self.client.spin()
153 162
154 163 self.assertRaises(IndexError, lambda : self.client[id0])
155 164
156 165 def test_result_status(self):
157 166 pass
158 167 # to be written
159 168
160 169 def test_db_query_dt(self):
161 170 """test db query by date"""
162 171 hist = self.client.hub_history()
163 172 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
164 173 tic = middle['submitted']
165 174 before = self.client.db_query({'submitted' : {'$lt' : tic}})
166 175 after = self.client.db_query({'submitted' : {'$gte' : tic}})
167 176 self.assertEquals(len(before)+len(after),len(hist))
168 177 for b in before:
169 178 self.assertTrue(b['submitted'] < tic)
170 179 for a in after:
171 180 self.assertTrue(a['submitted'] >= tic)
172 181 same = self.client.db_query({'submitted' : tic})
173 182 for s in same:
174 183 self.assertTrue(s['submitted'] == tic)
175 184
176 185 def test_db_query_keys(self):
177 186 """test extracting subset of record keys"""
178 187 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
179 188 for rec in found:
180 189 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
181 190
182 191 def test_db_query_msg_id(self):
183 192 """ensure msg_id is always in db queries"""
184 193 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
185 194 for rec in found:
186 195 self.assertTrue('msg_id' in rec.keys())
187 196 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
188 197 for rec in found:
189 198 self.assertTrue('msg_id' in rec.keys())
190 199 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
191 200 for rec in found:
192 201 self.assertTrue('msg_id' in rec.keys())
193 202
194 203 def test_db_query_in(self):
195 204 """test db query with '$in','$nin' operators"""
196 205 hist = self.client.hub_history()
197 206 even = hist[::2]
198 207 odd = hist[1::2]
199 208 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
200 209 found = [ r['msg_id'] for r in recs ]
201 210 self.assertEquals(set(even), set(found))
202 211 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
203 212 found = [ r['msg_id'] for r in recs ]
204 213 self.assertEquals(set(odd), set(found))
205 214
206 215 def test_hub_history(self):
207 216 hist = self.client.hub_history()
208 217 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
209 218 recdict = {}
210 219 for rec in recs:
211 220 recdict[rec['msg_id']] = rec
212 221
213 222 latest = datetime(1984,1,1)
214 223 for msg_id in hist:
215 224 rec = recdict[msg_id]
216 225 newt = rec['submitted']
217 226 self.assertTrue(newt >= latest)
218 227 latest = newt
219 228 ar = self.client[-1].apply_async(lambda : 1)
220 229 ar.get()
221 230 time.sleep(0.25)
222 231 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
223 232
224 233 def test_resubmit(self):
225 234 def f():
226 235 import random
227 236 return random.random()
228 237 v = self.client.load_balanced_view()
229 238 ar = v.apply_async(f)
230 239 r1 = ar.get(1)
231 240 ahr = self.client.resubmit(ar.msg_ids)
232 241 r2 = ahr.get(1)
233 242 self.assertFalse(r1 == r2)
234 243
235 244 def test_resubmit_inflight(self):
236 245 """ensure ValueError on resubmit of inflight task"""
237 246 v = self.client.load_balanced_view()
238 247 ar = v.apply_async(time.sleep,1)
239 248 # give the message a chance to arrive
240 249 time.sleep(0.2)
241 250 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
242 251 ar.get(2)
243 252
244 253 def test_resubmit_badkey(self):
245 254 """ensure KeyError on resubmit of nonexistant task"""
246 255 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
247 256
248 257 def test_purge_results(self):
249 258 # ensure there are some tasks
250 259 for i in range(5):
251 260 self.client[:].apply_sync(lambda : 1)
252 261 # Wait for the Hub to realise the result is done:
253 262 # This prevents a race condition, where we
254 263 # might purge a result the Hub still thinks is pending.
255 264 time.sleep(0.1)
256 265 rc2 = clientmod.Client(profile='iptest')
257 266 hist = self.client.hub_history()
258 267 ahr = rc2.get_result([hist[-1]])
259 268 ahr.wait(10)
260 269 self.client.purge_results(hist[-1])
261 270 newhist = self.client.hub_history()
262 271 self.assertEquals(len(newhist)+1,len(hist))
263 272 rc2.spin()
264 273 rc2.close()
265 274
266 275 def test_purge_all_results(self):
267 276 self.client.purge_results('all')
268 277 hist = self.client.hub_history()
269 278 self.assertEquals(len(hist), 0)
270 279
General Comments 0
You need to be logged in to leave comments. Login now