##// END OF EJS Templates
# handle targets='all' in remotefunction...
# MinRK -
# Show More
# @@ -1,1434 +1,1440 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
# xrange is used in a couple 'isinstance' tests in py2.
# Python 3 has no xrange builtin, so alias it to range there;
# on py2 the builtin already exists and this is a no-op.
try:
    xrange
except NameError:
    xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
@decorator
def spin_first(method, self, *args, **kwargs):
    """Decorator: flush pending messages via spin() before invoking *method*,
    so the client's view of cluster state is current."""
    self.spin()
    return method(self, *args, **kwargs)
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed schema: every Metadata starts with exactly these keys
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # caller-supplied values may only override existing keys
        # (enforced by the strict __setitem__ below via dict.update? no:
        #  dict.update bypasses __setitem__, so unknown kwargs are accepted
        #  here exactly as in the original implementation)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # NOTE: `key in self` is the O(1) membership test; the previous
        # `key in self.iterkeys()` was an O(n) scan and recursed infinitely
        # under Python 3 (dict has no iterkeys, so the lookup re-entered
        # __getattr__).
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
123 123
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 166 # and set sshserver as the same machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to ssh private key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False else]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use semi-private _ids attributes.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
244 244
    # -- public traits -----------------------------------------------------
    # default blocking behavior for execution methods when `block` not given
    block = Bool(False)
    # msg_ids submitted but whose replies have not yet been received
    outstanding = Set()
    # all received results, keyed by msg_id
    results = Instance('collections.defaultdict', (dict,))
    # per-msg_id Metadata records (timestamps, engine info, stdout/stderr)
    metadata = Instance('collections.defaultdict', (Metadata,))
    # every execution msg_id ever submitted, in submission order
    history = List()
    # when True, pprint messages as they are sent/received
    debug = Bool(False)

    # cluster profile name, used to locate connection files
    profile=Unicode()
253 253 def _profile_default(self):
254 254 if BaseIPythonApplication.initialized():
255 255 # an IPython app *might* be running, try to get its profile
256 256 try:
257 257 return BaseIPythonApplication.instance().profile
258 258 except (AttributeError, MultipleInstanceError):
259 259 # could be a *different* subclass of config.Application,
260 260 # which would raise one of these two errors.
261 261 return u'default'
262 262 else:
263 263 return u'default'
264 264
265 265
    # -- private state -----------------------------------------------------
    # map: engine uuid -> set of msg_ids outstanding on that engine
    _outstanding_dict = Instance('collections.defaultdict', (set,))
    # sorted list of registered engine ids
    _ids = List()
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    # connection/registration info (parsed connector file + hub reply)
    _config = Dict()
    # bidirectional engine id <-> uuid mapping
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    # scheduler scheme reported by the hub; 'pure' means pure-ZMQ scheduling
    _task_scheme=Unicode()
    _closed = False
    # counts of replies we have promised to drain without processing
    _ignored_control_replies=Int(0)
    _ignored_hub_replies=Int(0)
284 284
285 285 def __new__(self, *args, **kw):
286 286 # don't raise on positional args
287 287 return HasTraits.__new__(self, **kw)
288 288
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Connect to the Hub.

        See the class docstring for the meaning of each argument.  Any extra
        keyword arguments are relayed to the Session constructor
        (packer, unpacker, username, config, ...).
        """
        # only forward `profile` if explicitly given, so the
        # _profile_default trait logic runs otherwise
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        # resolve the profile directory (sets self._cd, possibly None)
        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                # default to the standard connector file in the security dir
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                if self._cd:
                    # retry relative to the profile's security dir
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # exec_key is a path: let the Session read the key file
                extra_args['keyfile'] = exec_key
            else:
                # exec_key is the key itself
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
387 387
388 388 def __del__(self):
389 389 """cleanup sockets, but _not_ context."""
390 390 self.close()
391 391
392 392 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 393 if ipython_dir is None:
394 394 ipython_dir = get_ipython_dir()
395 395 if profile_dir is not None:
396 396 try:
397 397 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 398 return
399 399 except ProfileDirError:
400 400 pass
401 401 elif profile is not None:
402 402 try:
403 403 self._cd = ProfileDir.find_profile_dir_by_name(
404 404 ipython_dir, profile)
405 405 return
406 406 except ProfileDirError:
407 407 pass
408 408 self._cd = None
409 409
410 410 def _update_engines(self, engines):
411 411 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 412 for k,v in engines.iteritems():
413 413 eid = int(k)
414 414 self._engines[eid] = v
415 415 self._ids.append(eid)
416 416 self._ids = sorted(self._ids)
417 417 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 418 self._task_scheme == 'pure' and self._task_socket:
419 419 self._stop_scheduling_tasks()
420 420
421 421 def _stop_scheduling_tasks(self):
422 422 """Stop scheduling tasks because an engine has been unregistered
423 423 from a pure ZMQ scheduler.
424 424 """
425 425 self._task_socket.close()
426 426 self._task_socket = None
427 427 msg = "An engine has been unregistered, and we are using pure " +\
428 428 "ZMQ task scheduling. Task farming will be disabled."
429 429 if self.outstanding:
430 430 msg += " If you were running tasks when this happened, " +\
431 431 "some `outstanding` msg_ids may never resolve."
432 432 warnings.warn(msg, RuntimeWarning)
433 433
434 434 def _build_targets(self, targets):
435 435 """Turn valid target IDs or 'all' into two lists:
436 436 (int_ids, uuids).
437 437 """
438 438 if not self._ids:
439 439 # flush notification socket if no engines yet, just in case
440 440 if not self.ids:
441 441 raise error.NoEnginesRegistered("Can't build targets without any engines")
442 442
443 443 if targets is None:
444 444 targets = self._ids
445 445 elif isinstance(targets, basestring):
446 446 if targets.lower() == 'all':
447 447 targets = self._ids
448 448 else:
449 449 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 450 elif isinstance(targets, int):
451 451 if targets < 0:
452 452 targets = self.ids[targets]
453 453 if targets not in self._ids:
454 454 raise IndexError("No such engine: %i"%targets)
455 455 targets = [targets]
456 456
457 457 if isinstance(targets, slice):
458 458 indices = range(len(self._ids))[targets]
459 459 ids = self.ids
460 460 targets = [ ids[i] for i in indices ]
461 461
462 462 if not isinstance(targets, (tuple, list, xrange)):
463 463 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464 464
465 465 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466 466
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request on the query socket, waits up to `timeout`
        seconds for the hub's reply, then connects one socket per channel the
        hub advertises (mux, task, notification, control, iopub).
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # tunnel through sshserver when ssh is configured
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # all DEALER sockets share the session's identity
            ident = self.session.bsession
            if content.mux:
                self._mux_socket = self._context.socket(zmq.DEALER)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.DEALER)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # subscribe to everything on the notification channel
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.DEALER)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.DEALER)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            # record the engines the hub already knows about
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
529 529
530 530 #--------------------------------------------------------------------------
531 531 # handlers and callbacks for incoming messages
532 532 #--------------------------------------------------------------------------
533 533
534 534 def _unwrap_exception(self, content):
535 535 """unwrap exception, and remap engine_id to int."""
536 536 e = error.unwrap_exception(content)
537 537 # print e.traceback
538 538 if e.engine_info:
539 539 e_uuid = e.engine_info['engine_uuid']
540 540 eid = self._engines[e_uuid]
541 541 e.engine_info['engine_id'] = eid
542 542 return e
543 543
544 544 def _extract_metadata(self, header, parent, content):
545 545 md = {'msg_id' : parent['msg_id'],
546 546 'received' : datetime.now(),
547 547 'engine_uuid' : header.get('engine', None),
548 548 'follow' : parent.get('follow', []),
549 549 'after' : parent.get('after', []),
550 550 'status' : content['status'],
551 551 }
552 552
553 553 if md['engine_uuid'] is not None:
554 554 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555 555
556 556 if 'date' in parent:
557 557 md['submitted'] = parent['date']
558 558 if 'started' in header:
559 559 md['started'] = header['started']
560 560 if 'date' in header:
561 561 md['completed'] = header['date']
562 562 return md
563 563
564 564 def _register_engine(self, msg):
565 565 """Register a new engine, and update our connection info."""
566 566 content = msg['content']
567 567 eid = content['id']
568 568 d = {eid : content['queue']}
569 569 self._update_engines(d)
570 570
571 571 def _unregister_engine(self, msg):
572 572 """Unregister an engine that has died."""
573 573 content = msg['content']
574 574 eid = int(content['id'])
575 575 if eid in self._ids:
576 576 self._ids.remove(eid)
577 577 uuid = self._engines.pop(eid)
578 578
579 579 self._handle_stranded_msgs(eid, uuid)
580 580
581 581 if self._task_socket and self._task_scheme == 'pure':
582 582 self._stop_scheduling_tasks()
583 583
584 584 def _handle_stranded_msgs(self, eid, uuid):
585 585 """Handle messages known to be on an engine when the engine unregisters.
586 586
587 587 It is possible that this will fire prematurely - that is, an engine will
588 588 go down after completing a result, and the client will be notified
589 589 of the unregistration and later receive the successful result.
590 590 """
591 591
592 592 outstanding = self._outstanding_dict[uuid]
593 593
594 594 for msg_id in list(outstanding):
595 595 if msg_id in self.results:
596 596 # we already
597 597 continue
598 598 try:
599 599 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 600 except:
601 601 content = error.wrap_exception()
602 602 # build a fake message:
603 603 parent = {}
604 604 header = {}
605 605 parent['msg_id'] = msg_id
606 606 header['engine'] = uuid
607 607 header['date'] = datetime.now()
608 608 msg = dict(parent_header=parent, header=header, content=content)
609 609 self._handle_apply_reply(msg)
610 610
611 611 def _handle_execute_reply(self, msg):
612 612 """Save the reply to an execute_request into our results.
613 613
614 614 execute messages are never actually used. apply is used instead.
615 615 """
616 616
617 617 parent = msg['parent_header']
618 618 msg_id = parent['msg_id']
619 619 if msg_id not in self.outstanding:
620 620 if msg_id in self.history:
621 621 print ("got stale result: %s"%msg_id)
622 622 else:
623 623 print ("got unknown result: %s"%msg_id)
624 624 else:
625 625 self.outstanding.remove(msg_id)
626 626 self.results[msg_id] = self._unwrap_exception(msg['content'])
627 627
628 628 def _handle_apply_reply(self, msg):
629 629 """Save the reply to an apply_request into our results."""
630 630 parent = msg['parent_header']
631 631 msg_id = parent['msg_id']
632 632 if msg_id not in self.outstanding:
633 633 if msg_id in self.history:
634 634 print ("got stale result: %s"%msg_id)
635 635 print self.results[msg_id]
636 636 print msg
637 637 else:
638 638 print ("got unknown result: %s"%msg_id)
639 639 else:
640 640 self.outstanding.remove(msg_id)
641 641 content = msg['content']
642 642 header = msg['header']
643 643
644 644 # construct metadata:
645 645 md = self.metadata[msg_id]
646 646 md.update(self._extract_metadata(header, parent, content))
647 647 # is this redundant?
648 648 self.metadata[msg_id] = md
649 649
650 650 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 651 if msg_id in e_outstanding:
652 652 e_outstanding.remove(msg_id)
653 653
654 654 # construct result:
655 655 if content['status'] == 'ok':
656 656 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 657 elif content['status'] == 'aborted':
658 658 self.results[msg_id] = error.TaskAborted(msg_id)
659 659 elif content['status'] == 'resubmitted':
660 660 # TODO: handle resubmission
661 661 pass
662 662 else:
663 663 self.results[msg_id] = self._unwrap_exception(content)
664 664
665 665 def _flush_notifications(self):
666 666 """Flush notifications of engine registrations waiting
667 667 in ZMQ queue."""
668 668 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 669 while msg is not None:
670 670 if self.debug:
671 671 pprint(msg)
672 672 msg_type = msg['header']['msg_type']
673 673 handler = self._notification_handlers.get(msg_type, None)
674 674 if handler is None:
675 675 raise Exception("Unhandled message type: %s"%msg.msg_type)
676 676 else:
677 677 handler(msg)
678 678 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679 679
680 680 def _flush_results(self, sock):
681 681 """Flush task or queue results waiting in ZMQ queue."""
682 682 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 683 while msg is not None:
684 684 if self.debug:
685 685 pprint(msg)
686 686 msg_type = msg['header']['msg_type']
687 687 handler = self._queue_handlers.get(msg_type, None)
688 688 if handler is None:
689 689 raise Exception("Unhandled message type: %s"%msg.msg_type)
690 690 else:
691 691 handler(msg)
692 692 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 693
694 694 def _flush_control(self, sock):
695 695 """Flush replies from the control channel waiting
696 696 in the ZMQ queue.
697 697
698 698 Currently: ignore them."""
699 699 if self._ignored_control_replies <= 0:
700 700 return
701 701 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 702 while msg is not None:
703 703 self._ignored_control_replies -= 1
704 704 if self.debug:
705 705 pprint(msg)
706 706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 707
708 708 def _flush_ignored_control(self):
709 709 """flush ignored control replies"""
710 710 while self._ignored_control_replies > 0:
711 711 self.session.recv(self._control_socket)
712 712 self._ignored_control_replies -= 1
713 713
714 714 def _flush_ignored_hub_replies(self):
715 715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 716 while msg is not None:
717 717 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718 718
719 719 def _flush_iopub(self, sock):
720 720 """Flush replies from the iopub channel waiting
721 721 in the ZMQ queue.
722 722 """
723 723 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 724 while msg is not None:
725 725 if self.debug:
726 726 pprint(msg)
727 727 parent = msg['parent_header']
728 728 msg_id = parent['msg_id']
729 729 content = msg['content']
730 730 header = msg['header']
731 731 msg_type = msg['header']['msg_type']
732 732
733 733 # init metadata:
734 734 md = self.metadata[msg_id]
735 735
736 736 if msg_type == 'stream':
737 737 name = content['name']
738 738 s = md[name] or ''
739 739 md[name] = s + content['data']
740 740 elif msg_type == 'pyerr':
741 741 md.update({'pyerr' : self._unwrap_exception(content)})
742 742 elif msg_type == 'pyin':
743 743 md.update({'pyin' : content['code']})
744 744 else:
745 745 md.update({msg_type : content.get('data', '')})
746 746
747 747 # reduntant?
748 748 self.metadata[msg_id] = md
749 749
750 750 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751 751
752 752 #--------------------------------------------------------------------------
753 753 # len, getitem
754 754 #--------------------------------------------------------------------------
755 755
756 756 def __len__(self):
757 757 """len(client) returns # of engines."""
758 758 return len(self.ids)
759 759
760 760 def __getitem__(self, key):
761 761 """index access returns DirectView multiplexer objects
762 762
763 763 Must be int, slice, or list/tuple/xrange of ints"""
764 764 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 765 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 766 else:
767 767 return self.direct_view(key)
768 768
769 769 #--------------------------------------------------------------------------
770 770 # Begin public methods
771 771 #--------------------------------------------------------------------------
772 772
773 773 @property
774 774 def ids(self):
775 775 """Always up-to-date ids property."""
776 776 self._flush_notifications()
777 777 # always copy:
778 778 return list(self._ids)
779 779
780 780 def close(self):
781 781 if self._closed:
782 782 return
783 783 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 784 for socket in map(lambda name: getattr(self, name), snames):
785 785 if isinstance(socket, zmq.Socket) and not socket.closed:
786 786 socket.close()
787 787 self._closed = True
788 788
789 789 def spin(self):
790 790 """Flush any registration notifications and execution results
791 791 waiting in the ZMQ queue.
792 792 """
793 793 if self._notification_socket:
794 794 self._flush_notifications()
795 795 if self._mux_socket:
796 796 self._flush_results(self._mux_socket)
797 797 if self._task_socket:
798 798 self._flush_results(self._task_socket)
799 799 if self._control_socket:
800 800 self._flush_control(self._control_socket)
801 801 if self._iopub_socket:
802 802 self._flush_iopub(self._iopub_socket)
803 803 if self._query_socket:
804 804 self._flush_ignored_hub_replies()
805 805
806 806 def wait(self, jobs=None, timeout=-1):
807 807 """waits on one or more `jobs`, for up to `timeout` seconds.
808 808
809 809 Parameters
810 810 ----------
811 811
812 812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 813 ints are indices to self.history
814 814 strs are msg_ids
815 815 default: wait on all outstanding messages
816 816 timeout : float
817 817 a time in seconds, after which to give up.
818 818 default is -1, which means no timeout
819 819
820 820 Returns
821 821 -------
822 822
823 823 True : when all msg_ids are done
824 824 False : timeout reached, some msg_ids still outstanding
825 825 """
826 826 tic = time.time()
827 827 if jobs is None:
828 828 theids = self.outstanding
829 829 else:
830 830 if isinstance(jobs, (int, basestring, AsyncResult)):
831 831 jobs = [jobs]
832 832 theids = set()
833 833 for job in jobs:
834 834 if isinstance(job, int):
835 835 # index access
836 836 job = self.history[job]
837 837 elif isinstance(job, AsyncResult):
838 838 map(theids.add, job.msg_ids)
839 839 continue
840 840 theids.add(job)
841 841 if not theids.intersection(self.outstanding):
842 842 return True
843 843 self.spin()
844 844 while theids.intersection(self.outstanding):
845 845 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 846 break
847 847 time.sleep(1e-3)
848 848 self.spin()
849 849 return len(theids.intersection(self.outstanding)) == 0
850 850
851 851 #--------------------------------------------------------------------------
852 852 # Control methods
853 853 #--------------------------------------------------------------------------
854 854
855 855 @spin_first
856 856 def clear(self, targets=None, block=None):
857 857 """Clear the namespace in target(s)."""
858 858 block = self.block if block is None else block
859 859 targets = self._build_targets(targets)[0]
860 860 for t in targets:
861 861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 862 error = False
863 863 if block:
864 864 self._flush_ignored_control()
865 865 for i in range(len(targets)):
866 866 idents,msg = self.session.recv(self._control_socket,0)
867 867 if self.debug:
868 868 pprint(msg)
869 869 if msg['content']['status'] != 'ok':
870 870 error = self._unwrap_exception(msg['content'])
871 871 else:
872 872 self._ignored_control_replies += len(targets)
873 873 if error:
874 874 raise error
875 875
876 876
877 877 @spin_first
878 878 def abort(self, jobs=None, targets=None, block=None):
879 879 """Abort specific jobs from the execution queues of target(s).
880 880
881 881 This is a mechanism to prevent jobs that have already been submitted
882 882 from executing.
883 883
884 884 Parameters
885 885 ----------
886 886
887 887 jobs : msg_id, list of msg_ids, or AsyncResult
888 888 The jobs to be aborted
889 889
890 890
891 891 """
892 892 block = self.block if block is None else block
893 893 targets = self._build_targets(targets)[0]
894 894 msg_ids = []
895 895 if isinstance(jobs, (basestring,AsyncResult)):
896 896 jobs = [jobs]
897 897 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
898 898 if bad_ids:
899 899 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
900 900 for j in jobs:
901 901 if isinstance(j, AsyncResult):
902 902 msg_ids.extend(j.msg_ids)
903 903 else:
904 904 msg_ids.append(j)
905 905 content = dict(msg_ids=msg_ids)
906 906 for t in targets:
907 907 self.session.send(self._control_socket, 'abort_request',
908 908 content=content, ident=t)
909 909 error = False
910 910 if block:
911 911 self._flush_ignored_control()
912 912 for i in range(len(targets)):
913 913 idents,msg = self.session.recv(self._control_socket,0)
914 914 if self.debug:
915 915 pprint(msg)
916 916 if msg['content']['status'] != 'ok':
917 917 error = self._unwrap_exception(msg['content'])
918 918 else:
919 919 self._ignored_control_replies += len(targets)
920 920 if error:
921 921 raise error
922 922
923 923 @spin_first
924 924 def shutdown(self, targets=None, restart=False, hub=False, block=None):
925 925 """Terminates one or more engine processes, optionally including the hub."""
926 926 block = self.block if block is None else block
927 927 if hub:
928 928 targets = 'all'
929 929 targets = self._build_targets(targets)[0]
930 930 for t in targets:
931 931 self.session.send(self._control_socket, 'shutdown_request',
932 932 content={'restart':restart},ident=t)
933 933 error = False
934 934 if block or hub:
935 935 self._flush_ignored_control()
936 936 for i in range(len(targets)):
937 937 idents,msg = self.session.recv(self._control_socket, 0)
938 938 if self.debug:
939 939 pprint(msg)
940 940 if msg['content']['status'] != 'ok':
941 941 error = self._unwrap_exception(msg['content'])
942 942 else:
943 943 self._ignored_control_replies += len(targets)
944 944
945 945 if hub:
946 946 time.sleep(0.25)
947 947 self.session.send(self._query_socket, 'shutdown_request')
948 948 idents,msg = self.session.recv(self._query_socket, 0)
949 949 if self.debug:
950 950 pprint(msg)
951 951 if msg['content']['status'] != 'ok':
952 952 error = self._unwrap_exception(msg['content'])
953 953
954 954 if error:
955 955 raise error
956 956
957 957 #--------------------------------------------------------------------------
958 958 # Execution related methods
959 959 #--------------------------------------------------------------------------
960 960
961 961 def _maybe_raise(self, result):
962 962 """wrapper for maybe raising an exception if apply failed."""
963 963 if isinstance(result, error.RemoteError):
964 964 raise result
965 965
966 966 return result
967 967
968 968 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
969 969 ident=None):
970 970 """construct and send an apply message via a socket.
971 971
972 972 This is the principal method with which all engine execution is performed by views.
973 973 """
974 974
975 975 assert not self._closed, "cannot use me anymore, I'm closed!"
976 976 # defaults:
977 977 args = args if args is not None else []
978 978 kwargs = kwargs if kwargs is not None else {}
979 979 subheader = subheader if subheader is not None else {}
980 980
981 981 # validate arguments
982 982 if not callable(f):
983 983 raise TypeError("f must be callable, not %s"%type(f))
984 984 if not isinstance(args, (tuple, list)):
985 985 raise TypeError("args must be tuple or list, not %s"%type(args))
986 986 if not isinstance(kwargs, dict):
987 987 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
988 988 if not isinstance(subheader, dict):
989 989 raise TypeError("subheader must be dict, not %s"%type(subheader))
990 990
991 991 bufs = util.pack_apply_message(f,args,kwargs)
992 992
993 993 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
994 994 subheader=subheader, track=track)
995 995
996 996 msg_id = msg['header']['msg_id']
997 997 self.outstanding.add(msg_id)
998 998 if ident:
999 999 # possibly routed to a specific engine
1000 1000 if isinstance(ident, list):
1001 1001 ident = ident[-1]
1002 1002 if ident in self._engines.values():
1003 1003 # save for later, in case of engine death
1004 1004 self._outstanding_dict[ident].add(msg_id)
1005 1005 self.history.append(msg_id)
1006 1006 self.metadata[msg_id]['submitted'] = datetime.now()
1007 1007
1008 1008 return msg
1009 1009
1010 1010 #--------------------------------------------------------------------------
1011 1011 # construct a View object
1012 1012 #--------------------------------------------------------------------------
1013 1013
1014 1014 def load_balanced_view(self, targets=None):
1015 1015 """construct a LoadBalancedView object.
1016 1016
1017 1017 If no arguments are specified, create a LoadBalancedView
1018 1018 using all engines.
1019 1019
1020 1020 Parameters
1021 1021 ----------
1022 1022
1023 1023 targets: list,slice,int,etc. [default: use all engines]
1024 1024 The subset of engines across which to load-balance
1025 1025 """
1026 1026 if targets == 'all':
1027 1027 targets = None
1028 1028 if targets is not None:
1029 1029 targets = self._build_targets(targets)[1]
1030 1030 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1031 1031
1032 1032 def direct_view(self, targets='all'):
1033 1033 """construct a DirectView object.
1034 1034
1035 If no targets are specified, create a DirectView
1036 using all engines.
1035 If no targets are specified, create a DirectView using all engines.
1036
1037 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1038 evaluate the target engines at each execution, whereas rc[:] will connect to
1039 all *current* engines, and that list will not change.
1040
1041 That is, 'all' will always use all engines, whereas rc[:] will not use
1042 engines added after the DirectView is constructed.
1037 1043
1038 1044 Parameters
1039 1045 ----------
1040 1046
1041 1047 targets: list,slice,int,etc. [default: use all engines]
1042 1048 The engines to use for the View
1043 1049 """
1044 1050 single = isinstance(targets, int)
1045 1051 # allow 'all' to be lazily evaluated at each execution
1046 1052 if targets != 'all':
1047 1053 targets = self._build_targets(targets)[1]
1048 1054 if single:
1049 1055 targets = targets[0]
1050 1056 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1051 1057
1052 1058 #--------------------------------------------------------------------------
1053 1059 # Query methods
1054 1060 #--------------------------------------------------------------------------
1055 1061
1056 1062 @spin_first
1057 1063 def get_result(self, indices_or_msg_ids=None, block=None):
1058 1064 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1059 1065
1060 1066 If the client already has the results, no request to the Hub will be made.
1061 1067
1062 1068 This is a convenient way to construct AsyncResult objects, which are wrappers
1063 1069 that include metadata about execution, and allow for awaiting results that
1064 1070 were not submitted by this Client.
1065 1071
1066 1072 It can also be a convenient way to retrieve the metadata associated with
1067 1073 blocking execution, since it always retrieves the metadata from the Hub.
1068 1074
1069 1075 Examples
1070 1076 --------
1071 1077 ::
1072 1078
1073 1079 In [10]: r = client.apply()
1074 1080
1075 1081 Parameters
1076 1082 ----------
1077 1083
1078 1084 indices_or_msg_ids : integer history index, str msg_id, or list of either
1079 1085 The indices or msg_ids of indices to be retrieved
1080 1086
1081 1087 block : bool
1082 1088 Whether to wait for the result to be done
1083 1089
1084 1090 Returns
1085 1091 -------
1086 1092
1087 1093 AsyncResult
1088 1094 A single AsyncResult object will always be returned.
1089 1095
1090 1096 AsyncHubResult
1091 1097 A subclass of AsyncResult that retrieves results from the Hub
1092 1098
1093 1099 """
1094 1100 block = self.block if block is None else block
1095 1101 if indices_or_msg_ids is None:
1096 1102 indices_or_msg_ids = -1
1097 1103
1098 1104 if not isinstance(indices_or_msg_ids, (list,tuple)):
1099 1105 indices_or_msg_ids = [indices_or_msg_ids]
1100 1106
1101 1107 theids = []
1102 1108 for id in indices_or_msg_ids:
1103 1109 if isinstance(id, int):
1104 1110 id = self.history[id]
1105 1111 if not isinstance(id, basestring):
1106 1112 raise TypeError("indices must be str or int, not %r"%id)
1107 1113 theids.append(id)
1108 1114
1109 1115 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1110 1116 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1111 1117
1112 1118 if remote_ids:
1113 1119 ar = AsyncHubResult(self, msg_ids=theids)
1114 1120 else:
1115 1121 ar = AsyncResult(self, msg_ids=theids)
1116 1122
1117 1123 if block:
1118 1124 ar.wait()
1119 1125
1120 1126 return ar
1121 1127
1122 1128 @spin_first
1123 1129 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1124 1130 """Resubmit one or more tasks.
1125 1131
1126 1132 in-flight tasks may not be resubmitted.
1127 1133
1128 1134 Parameters
1129 1135 ----------
1130 1136
1131 1137 indices_or_msg_ids : integer history index, str msg_id, or list of either
1132 1138 The indices or msg_ids of indices to be retrieved
1133 1139
1134 1140 block : bool
1135 1141 Whether to wait for the result to be done
1136 1142
1137 1143 Returns
1138 1144 -------
1139 1145
1140 1146 AsyncHubResult
1141 1147 A subclass of AsyncResult that retrieves results from the Hub
1142 1148
1143 1149 """
1144 1150 block = self.block if block is None else block
1145 1151 if indices_or_msg_ids is None:
1146 1152 indices_or_msg_ids = -1
1147 1153
1148 1154 if not isinstance(indices_or_msg_ids, (list,tuple)):
1149 1155 indices_or_msg_ids = [indices_or_msg_ids]
1150 1156
1151 1157 theids = []
1152 1158 for id in indices_or_msg_ids:
1153 1159 if isinstance(id, int):
1154 1160 id = self.history[id]
1155 1161 if not isinstance(id, basestring):
1156 1162 raise TypeError("indices must be str or int, not %r"%id)
1157 1163 theids.append(id)
1158 1164
1159 1165 for msg_id in theids:
1160 1166 self.outstanding.discard(msg_id)
1161 1167 if msg_id in self.history:
1162 1168 self.history.remove(msg_id)
1163 1169 self.results.pop(msg_id, None)
1164 1170 self.metadata.pop(msg_id, None)
1165 1171 content = dict(msg_ids = theids)
1166 1172
1167 1173 self.session.send(self._query_socket, 'resubmit_request', content)
1168 1174
1169 1175 zmq.select([self._query_socket], [], [])
1170 1176 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1171 1177 if self.debug:
1172 1178 pprint(msg)
1173 1179 content = msg['content']
1174 1180 if content['status'] != 'ok':
1175 1181 raise self._unwrap_exception(content)
1176 1182
1177 1183 ar = AsyncHubResult(self, msg_ids=theids)
1178 1184
1179 1185 if block:
1180 1186 ar.wait()
1181 1187
1182 1188 return ar
1183 1189
1184 1190 @spin_first
1185 1191 def result_status(self, msg_ids, status_only=True):
1186 1192 """Check on the status of the result(s) of the apply request with `msg_ids`.
1187 1193
1188 1194 If status_only is False, then the actual results will be retrieved, else
1189 1195 only the status of the results will be checked.
1190 1196
1191 1197 Parameters
1192 1198 ----------
1193 1199
1194 1200 msg_ids : list of msg_ids
1195 1201 if int:
1196 1202 Passed as index to self.history for convenience.
1197 1203 status_only : bool (default: True)
1198 1204 if False:
1199 1205 Retrieve the actual results of completed tasks.
1200 1206
1201 1207 Returns
1202 1208 -------
1203 1209
1204 1210 results : dict
1205 1211 There will always be the keys 'pending' and 'completed', which will
1206 1212 be lists of msg_ids that are incomplete or complete. If `status_only`
1207 1213 is False, then completed results will be keyed by their `msg_id`.
1208 1214 """
1209 1215 if not isinstance(msg_ids, (list,tuple)):
1210 1216 msg_ids = [msg_ids]
1211 1217
1212 1218 theids = []
1213 1219 for msg_id in msg_ids:
1214 1220 if isinstance(msg_id, int):
1215 1221 msg_id = self.history[msg_id]
1216 1222 if not isinstance(msg_id, basestring):
1217 1223 raise TypeError("msg_ids must be str, not %r"%msg_id)
1218 1224 theids.append(msg_id)
1219 1225
1220 1226 completed = []
1221 1227 local_results = {}
1222 1228
1223 1229 # comment this block out to temporarily disable local shortcut:
1224 1230 for msg_id in theids:
1225 1231 if msg_id in self.results:
1226 1232 completed.append(msg_id)
1227 1233 local_results[msg_id] = self.results[msg_id]
1228 1234 theids.remove(msg_id)
1229 1235
1230 1236 if theids: # some not locally cached
1231 1237 content = dict(msg_ids=theids, status_only=status_only)
1232 1238 msg = self.session.send(self._query_socket, "result_request", content=content)
1233 1239 zmq.select([self._query_socket], [], [])
1234 1240 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1235 1241 if self.debug:
1236 1242 pprint(msg)
1237 1243 content = msg['content']
1238 1244 if content['status'] != 'ok':
1239 1245 raise self._unwrap_exception(content)
1240 1246 buffers = msg['buffers']
1241 1247 else:
1242 1248 content = dict(completed=[],pending=[])
1243 1249
1244 1250 content['completed'].extend(completed)
1245 1251
1246 1252 if status_only:
1247 1253 return content
1248 1254
1249 1255 failures = []
1250 1256 # load cached results into result:
1251 1257 content.update(local_results)
1252 1258
1253 1259 # update cache with results:
1254 1260 for msg_id in sorted(theids):
1255 1261 if msg_id in content['completed']:
1256 1262 rec = content[msg_id]
1257 1263 parent = rec['header']
1258 1264 header = rec['result_header']
1259 1265 rcontent = rec['result_content']
1260 1266 iodict = rec['io']
1261 1267 if isinstance(rcontent, str):
1262 1268 rcontent = self.session.unpack(rcontent)
1263 1269
1264 1270 md = self.metadata[msg_id]
1265 1271 md.update(self._extract_metadata(header, parent, rcontent))
1266 1272 md.update(iodict)
1267 1273
1268 1274 if rcontent['status'] == 'ok':
1269 1275 res,buffers = util.unserialize_object(buffers)
1270 1276 else:
1271 1277 print rcontent
1272 1278 res = self._unwrap_exception(rcontent)
1273 1279 failures.append(res)
1274 1280
1275 1281 self.results[msg_id] = res
1276 1282 content[msg_id] = res
1277 1283
1278 1284 if len(theids) == 1 and failures:
1279 1285 raise failures[0]
1280 1286
1281 1287 error.collect_exceptions(failures, "result_status")
1282 1288 return content
1283 1289
1284 1290 @spin_first
1285 1291 def queue_status(self, targets='all', verbose=False):
1286 1292 """Fetch the status of engine queues.
1287 1293
1288 1294 Parameters
1289 1295 ----------
1290 1296
1291 1297 targets : int/str/list of ints/strs
1292 1298 the engines whose states are to be queried.
1293 1299 default : all
1294 1300 verbose : bool
1295 1301 Whether to return lengths only, or lists of ids for each element
1296 1302 """
1297 1303 engine_ids = self._build_targets(targets)[1]
1298 1304 content = dict(targets=engine_ids, verbose=verbose)
1299 1305 self.session.send(self._query_socket, "queue_request", content=content)
1300 1306 idents,msg = self.session.recv(self._query_socket, 0)
1301 1307 if self.debug:
1302 1308 pprint(msg)
1303 1309 content = msg['content']
1304 1310 status = content.pop('status')
1305 1311 if status != 'ok':
1306 1312 raise self._unwrap_exception(content)
1307 1313 content = rekey(content)
1308 1314 if isinstance(targets, int):
1309 1315 return content[targets]
1310 1316 else:
1311 1317 return content
1312 1318
1313 1319 @spin_first
1314 1320 def purge_results(self, jobs=[], targets=[]):
1315 1321 """Tell the Hub to forget results.
1316 1322
1317 1323 Individual results can be purged by msg_id, or the entire
1318 1324 history of specific targets can be purged.
1319 1325
1320 1326 Use `purge_results('all')` to scrub everything from the Hub's db.
1321 1327
1322 1328 Parameters
1323 1329 ----------
1324 1330
1325 1331 jobs : str or list of str or AsyncResult objects
1326 1332 the msg_ids whose results should be forgotten.
1327 1333 targets : int/str/list of ints/strs
1328 1334 The targets, by int_id, whose entire history is to be purged.
1329 1335
1330 1336 default : None
1331 1337 """
1332 1338 if not targets and not jobs:
1333 1339 raise ValueError("Must specify at least one of `targets` and `jobs`")
1334 1340 if targets:
1335 1341 targets = self._build_targets(targets)[1]
1336 1342
1337 1343 # construct msg_ids from jobs
1338 1344 if jobs == 'all':
1339 1345 msg_ids = jobs
1340 1346 else:
1341 1347 msg_ids = []
1342 1348 if isinstance(jobs, (basestring,AsyncResult)):
1343 1349 jobs = [jobs]
1344 1350 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1345 1351 if bad_ids:
1346 1352 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1347 1353 for j in jobs:
1348 1354 if isinstance(j, AsyncResult):
1349 1355 msg_ids.extend(j.msg_ids)
1350 1356 else:
1351 1357 msg_ids.append(j)
1352 1358
1353 1359 content = dict(engine_ids=targets, msg_ids=msg_ids)
1354 1360 self.session.send(self._query_socket, "purge_request", content=content)
1355 1361 idents, msg = self.session.recv(self._query_socket, 0)
1356 1362 if self.debug:
1357 1363 pprint(msg)
1358 1364 content = msg['content']
1359 1365 if content['status'] != 'ok':
1360 1366 raise self._unwrap_exception(content)
1361 1367
1362 1368 @spin_first
1363 1369 def hub_history(self):
1364 1370 """Get the Hub's history
1365 1371
1366 1372 Just like the Client, the Hub has a history, which is a list of msg_ids.
1367 1373 This will contain the history of all clients, and, depending on configuration,
1368 1374 may contain history across multiple cluster sessions.
1369 1375
1370 1376 Any msg_id returned here is a valid argument to `get_result`.
1371 1377
1372 1378 Returns
1373 1379 -------
1374 1380
1375 1381 msg_ids : list of strs
1376 1382 list of all msg_ids, ordered by task submission time.
1377 1383 """
1378 1384
1379 1385 self.session.send(self._query_socket, "history_request", content={})
1380 1386 idents, msg = self.session.recv(self._query_socket, 0)
1381 1387
1382 1388 if self.debug:
1383 1389 pprint(msg)
1384 1390 content = msg['content']
1385 1391 if content['status'] != 'ok':
1386 1392 raise self._unwrap_exception(content)
1387 1393 else:
1388 1394 return content['history']
1389 1395
1390 1396 @spin_first
1391 1397 def db_query(self, query, keys=None):
1392 1398 """Query the Hub's TaskRecord database
1393 1399
1394 1400 This will return a list of task record dicts that match `query`
1395 1401
1396 1402 Parameters
1397 1403 ----------
1398 1404
1399 1405 query : mongodb query dict
1400 1406 The search dict. See mongodb query docs for details.
1401 1407 keys : list of strs [optional]
1402 1408 The subset of keys to be returned. The default is to fetch everything but buffers.
1403 1409 'msg_id' will *always* be included.
1404 1410 """
1405 1411 if isinstance(keys, basestring):
1406 1412 keys = [keys]
1407 1413 content = dict(query=query, keys=keys)
1408 1414 self.session.send(self._query_socket, "db_request", content=content)
1409 1415 idents, msg = self.session.recv(self._query_socket, 0)
1410 1416 if self.debug:
1411 1417 pprint(msg)
1412 1418 content = msg['content']
1413 1419 if content['status'] != 'ok':
1414 1420 raise self._unwrap_exception(content)
1415 1421
1416 1422 records = content['records']
1417 1423
1418 1424 buffer_lens = content['buffer_lens']
1419 1425 result_buffer_lens = content['result_buffer_lens']
1420 1426 buffers = msg['buffers']
1421 1427 has_bufs = buffer_lens is not None
1422 1428 has_rbufs = result_buffer_lens is not None
1423 1429 for i,rec in enumerate(records):
1424 1430 # relink buffers
1425 1431 if has_bufs:
1426 1432 blen = buffer_lens[i]
1427 1433 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1428 1434 if has_rbufs:
1429 1435 blen = result_buffer_lens[i]
1430 1436 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1431 1437
1432 1438 return records
1433 1439
1434 1440 __all__ = [ 'Client' ]
@@ -1,219 +1,222 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.testing.skipdoctest import skip_doctest
25 25
26 26 from . import map as Map
27 27 from .asyncresult import AsyncMapResult
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Decorators
31 31 #-----------------------------------------------------------------------------
32 32
33 33 @skip_doctest
34 34 def remote(view, block=None, **flags):
35 35 """Turn a function into a remote function.
36 36
37 37 This method can be used for map:
38 38
39 39 In [1]: @remote(view,block=True)
40 40 ...: def func(a):
41 41 ...: pass
42 42 """
43 43
44 44 def remote_function(f):
45 45 return RemoteFunction(view, f, block=block, **flags)
46 46 return remote_function
47 47
48 48 @skip_doctest
49 49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 50 """Turn a function into a parallel remote function.
51 51
52 52 This method can be used for map:
53 53
54 54 In [1]: @parallel(view, block=True)
55 55 ...: def func(a):
56 56 ...: pass
57 57 """
58 58
59 59 def parallel_function(f):
60 60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 61 return parallel_function
62 62
63 63 #--------------------------------------------------------------------------
64 64 # Classes
65 65 #--------------------------------------------------------------------------
66 66
67 67 class RemoteFunction(object):
68 68 """Turn an existing function into a remote function.
69 69
70 70 Parameters
71 71 ----------
72 72
73 73 view : View instance
74 74 The view to be used for execution
75 75 f : callable
76 76 The function to be wrapped into a remote function
77 77 block : bool [default: None]
78 78 Whether to wait for results or not. The default behavior is
79 79 to use the current `block` attribute of `view`
80 80
81 81 **flags : remaining kwargs are passed to View.temp_flags
82 82 """
83 83
84 84 view = None # the remote connection
85 85 func = None # the wrapped function
86 86 block = None # whether to block
87 87 flags = None # dict of extra kwargs for temp_flags
88 88
89 89 def __init__(self, view, f, block=None, **flags):
90 90 self.view = view
91 91 self.func = f
92 92 self.block=block
93 93 self.flags=flags
94 94
95 95 def __call__(self, *args, **kwargs):
96 96 block = self.view.block if self.block is None else self.block
97 97 with self.view.temp_flags(block=block, **self.flags):
98 98 return self.view.apply(self.func, *args, **kwargs)
99 99
100 100
101 101 class ParallelFunction(RemoteFunction):
102 102 """Class for mapping a function to sequences.
103 103
104 104 This will distribute the sequences according to a mapper, and call
105 105 the function on each sub-sequence. If called via map, then the function
106 106 will be called once on each element, rather than on each sub-sequence.
107 107
108 108 Parameters
109 109 ----------
110 110
111 111 view : View instance
112 112 The view to be used for execution
113 113 f : callable
114 114 The function to be wrapped into a remote function
115 115 dist : str [default: 'b']
116 116 The key for which mapObject to use to distribute sequences
117 117 options are:
118 118 * 'b' : use contiguous chunks in order
119 119 * 'r' : use round-robin striping
120 120 block : bool [default: None]
121 121 Whether to wait for results or not. The default behavior is
122 122 to use the current `block` attribute of `view`
123 123 chunksize : int or None
124 124 The size of chunk to use when breaking up sequences in a load-balanced manner
125 125 ordered : bool [default: True]
126 126 Whether the results should be kept in submission order (True), or made available as they arrive (False)
127 127 **flags : remaining kwargs are passed to View.temp_flags
128 128 """
129 129
130 130 chunksize=None
131 131 ordered=None
132 132 mapObject=None
133 133
134 134 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
135 135 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
136 136 self.chunksize = chunksize
137 137 self.ordered = ordered
138 138
139 139 mapClass = Map.dists[dist]
140 140 self.mapObject = mapClass()
141 141
142 142 def __call__(self, *sequences):
143 client = self.view.client
144
143 145 # check that the length of sequences match
144 146 len_0 = len(sequences[0])
145 147 for s in sequences:
146 148 if len(s)!=len_0:
147 149 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
148 150 raise ValueError(msg)
149 151 balanced = 'Balanced' in self.view.__class__.__name__
150 152 if balanced:
151 153 if self.chunksize:
152 154 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
153 155 else:
154 156 nparts = len_0
155 157 targets = [None]*nparts
156 158 else:
157 159 if self.chunksize:
158 160 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
159 161 # multiplexed:
160 162 targets = self.view.targets
163 # 'all' is lazily evaluated at execution time, which is now:
164 if targets == 'all':
165 targets = client._build_targets(targets)[1]
161 166 nparts = len(targets)
162 167
163 168 msg_ids = []
164 # my_f = lambda *a: map(self.func, *a)
165 client = self.view.client
166 169 for index, t in enumerate(targets):
167 170 args = []
168 171 for seq in sequences:
169 172 part = self.mapObject.getPartition(seq, index, nparts)
170 173 if len(part) == 0:
171 174 continue
172 175 else:
173 176 args.append(part)
174 177 if not args:
175 178 continue
176 179
177 180 # print (args)
178 181 if hasattr(self, '_map'):
179 182 if sys.version_info[0] >= 3:
180 183 f = lambda f, *sequences: list(map(f, *sequences))
181 184 else:
182 185 f = map
183 186 args = [self.func]+args
184 187 else:
185 188 f=self.func
186 189
187 190 view = self.view if balanced else client[t]
188 191 with view.temp_flags(block=False, **self.flags):
189 192 ar = view.apply(f, *args)
190 193
191 194 msg_ids.append(ar.msg_ids[0])
192 195
193 196 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
194 197 fname=self.func.__name__,
195 198 ordered=self.ordered
196 199 )
197 200
198 201 if self.block:
199 202 try:
200 203 return r.get()
201 204 except KeyboardInterrupt:
202 205 return r
203 206 else:
204 207 return r
205 208
206 209 def map(self, *sequences):
207 210 """call a function on each element of a sequence remotely.
208 211 This should behave very much like the builtin map, but return an AsyncMapResult
209 212 if self.block is False.
210 213 """
211 214 # set _map as a flag for use inside self.__call__
212 215 self._map = True
213 216 try:
214 217 ret = self.__call__(*sequences)
215 218 finally:
216 219 del self._map
217 220 return ret
218 221
219 222 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,281 +1,316 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython.parallel.client import client as clientmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult
30 30 from IPython.parallel import LoadBalancedView, DirectView
31 31
32 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 33
# nose module-level setup: ensure at least 4 engines are registered with
# the test cluster before any TestClient test runs.
34 34 def setup():
35 35 add_engines(4)
36 36
# Integration tests for IPython.parallel's Client. Every test drives a live
# test cluster (profile 'iptest') provided by clienttest/add_engines — these
# are not isolated unit tests, hence the timing sleeps below.
37 37 class TestClient(ClusterTestCase):
38 38 
39 39 def test_ids(self):
40 40 n = len(self.client.ids)
41 41 self.add_engines(3)
42 42 self.assertEquals(len(self.client.ids), n+3)
43 43 
44 44 def test_view_indexing(self):
45 45 """test index access for views"""
46 46 self.add_engines(2)
47 47 targets = self.client._build_targets('all')[-1]
48 48 v = self.client[:]
49 49 self.assertEquals(v.targets, targets)
50 50 t = self.client.ids[2]
51 51 v = self.client[t]
52 52 self.assert_(isinstance(v, DirectView))
53 53 self.assertEquals(v.targets, t)
54 54 t = self.client.ids[2:4]
55 55 v = self.client[t]
56 56 self.assert_(isinstance(v, DirectView))
57 57 self.assertEquals(v.targets, t)
58 58 v = self.client[::2]
59 59 self.assert_(isinstance(v, DirectView))
60 60 self.assertEquals(v.targets, targets[::2])
61 61 v = self.client[1::3]
62 62 self.assert_(isinstance(v, DirectView))
63 63 self.assertEquals(v.targets, targets[1::3])
64 64 v = self.client[:-3]
65 65 self.assert_(isinstance(v, DirectView))
66 66 self.assertEquals(v.targets, targets[:-3])
67 67 v = self.client[-1]
68 68 self.assert_(isinstance(v, DirectView))
69 69 self.assertEquals(v.targets, targets[-1])
70 70 self.assertRaises(TypeError, lambda : self.client[None])
71 71 
72 72 def test_lbview_targets(self):
73 73 """test load_balanced_view targets"""
# targets=None on a LoadBalancedView means "let the scheduler choose".
74 74 v = self.client.load_balanced_view()
75 75 self.assertEquals(v.targets, None)
76 76 v = self.client.load_balanced_view(-1)
77 77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 78 v = self.client.load_balanced_view('all')
79 79 self.assertEquals(v.targets, None)
80 80 
81 81 def test_dview_targets(self):
82 """test load_balanced_view targets"""
82 """test direct_view targets"""
83 83 v = self.client.direct_view()
84 84 self.assertEquals(v.targets, 'all')
85 85 v = self.client.direct_view('all')
86 86 self.assertEquals(v.targets, 'all')
87 87 v = self.client.direct_view(-1)
88 88 self.assertEquals(v.targets, self.client.ids[-1])
89 89 
# New in this commit: targets='all' must stay unresolved on the view, so
# engines added *after* the view was created are still used by apply/map.
90 def test_lazy_all_targets(self):
91 """test lazy evaluation of rc.direct_view('all')"""
92 v = self.client.direct_view()
93 self.assertEquals(v.targets, 'all')
94 
95 def double(x):
96 return x*2
97 seq = range(100)
98 ref = [ double(x) for x in seq ]
99 
100 # add some engines, which should be used
101 self.add_engines(2)
102 n1 = len(self.client.ids)
103 
104 # simple apply
105 r = v.apply_sync(lambda : 1)
106 self.assertEquals(r, [1] * n1)
107 
108 # map goes through remotefunction
109 r = v.map_sync(double, seq)
110 self.assertEquals(r, ref)
111 
112 # add a couple more engines, and try again
113 self.add_engines(2)
114 n2 = len(self.client.ids)
115 self.assertNotEquals(n2, n1)
116 
117 # apply
118 r = v.apply_sync(lambda : 1)
119 self.assertEquals(r, [1] * n2)
120 
121 # map
122 r = v.map_sync(double, seq)
123 self.assertEquals(r, ref)
124 
90 125 def test_targets(self):
91 126 """test various valid targets arguments"""
92 127 build = self.client._build_targets
93 128 ids = self.client.ids
94 129 idents,targets = build(None)
95 130 self.assertEquals(ids, targets)
96 131 
97 132 def test_clear(self):
98 133 """test clear behavior"""
99 134 # self.add_engines(2)
100 135 v = self.client[:]
101 136 v.block=True
102 137 v.push(dict(a=5))
103 138 v.pull('a')
104 139 id0 = self.client.ids[-1]
# clear on one engine must not disturb the namespaces of the others.
105 140 self.client.clear(targets=id0, block=True)
106 141 a = self.client[:-1].get('a')
107 142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
108 143 self.client.clear(block=True)
109 144 for i in self.client.ids:
110 145 # print i
111 146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
112 147 
113 148 def test_get_result(self):
114 149 """test getting results from the Hub."""
# A second client submits, so the result is *not* in self.client's local
# cache and the first fetch must come from the Hub (AsyncHubResult);
# a repeated fetch is then served locally (plain AsyncResult).
115 150 c = clientmod.Client(profile='iptest')
116 151 # self.add_engines(1)
117 152 t = c.ids[-1]
118 153 ar = c[t].apply_async(wait, 1)
119 154 # give the monitor time to notice the message
120 155 time.sleep(.25)
121 156 ahr = self.client.get_result(ar.msg_ids)
122 157 self.assertTrue(isinstance(ahr, AsyncHubResult))
123 158 self.assertEquals(ahr.get(), ar.get())
124 159 ar2 = self.client.get_result(ar.msg_ids)
125 160 self.assertFalse(isinstance(ar2, AsyncHubResult))
126 161 c.close()
127 162 
128 163 def test_ids_list(self):
129 164 """test client.ids"""
# client.ids must be a *copy* of the internal list, so mutating it
# cannot corrupt the client's own bookkeeping.
130 165 ids = self.client.ids
131 166 self.assertEquals(ids, self.client._ids)
132 167 self.assertFalse(ids is self.client._ids)
133 168 ids.remove(ids[-1])
134 169 self.assertNotEquals(ids, self.client._ids)
135 170 
136 171 def test_queue_status(self):
137 172 # self.addEngine(4)
138 173 ids = self.client.ids
139 174 id0 = ids[0]
140 175 qs = self.client.queue_status(targets=id0)
141 176 self.assertTrue(isinstance(qs, dict))
142 177 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
# The all-targets form additionally carries an 'unassigned' count
# alongside the per-engine status dicts.
143 178 allqs = self.client.queue_status()
144 179 self.assertTrue(isinstance(allqs, dict))
145 180 intkeys = list(allqs.keys())
146 181 intkeys.remove('unassigned')
147 182 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
148 183 unassigned = allqs.pop('unassigned')
149 184 for eid,qs in allqs.items():
150 185 self.assertTrue(isinstance(qs, dict))
151 186 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
152 187 
# Shuts one engine down, then polls until the Hub notices it is gone.
153 188 def test_shutdown(self):
154 189 # self.addEngine(4)
155 190 ids = self.client.ids
156 191 id0 = ids[0]
157 192 self.client.shutdown(id0, block=True)
158 193 while id0 in self.client.ids:
159 194 time.sleep(0.1)
160 195 self.client.spin()
161 196 
162 197 self.assertRaises(IndexError, lambda : self.client[id0])
163 198 
164 199 def test_result_status(self):
165 200 pass
166 201 # to be written
167 202 
168 203 def test_db_query_dt(self):
169 204 """test db query by date"""
# $lt/$gte on the same pivot must partition the history exactly.
170 205 hist = self.client.hub_history()
171 206 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
172 207 tic = middle['submitted']
173 208 before = self.client.db_query({'submitted' : {'$lt' : tic}})
174 209 after = self.client.db_query({'submitted' : {'$gte' : tic}})
175 210 self.assertEquals(len(before)+len(after),len(hist))
176 211 for b in before:
177 212 self.assertTrue(b['submitted'] < tic)
178 213 for a in after:
179 214 self.assertTrue(a['submitted'] >= tic)
180 215 same = self.client.db_query({'submitted' : tic})
181 216 for s in same:
182 217 self.assertTrue(s['submitted'] == tic)
183 218 
184 219 def test_db_query_keys(self):
185 220 """test extracting subset of record keys"""
186 221 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
187 222 for rec in found:
188 223 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
189 224 
190 225 def test_db_query_msg_id(self):
191 226 """ensure msg_id is always in db queries"""
192 227 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
193 228 for rec in found:
194 229 self.assertTrue('msg_id' in rec.keys())
195 230 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
196 231 for rec in found:
197 232 self.assertTrue('msg_id' in rec.keys())
198 233 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
199 234 for rec in found:
200 235 self.assertTrue('msg_id' in rec.keys())
201 236 
202 237 def test_db_query_in(self):
203 238 """test db query with '$in','$nin' operators"""
204 239 hist = self.client.hub_history()
205 240 even = hist[::2]
206 241 odd = hist[1::2]
207 242 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
208 243 found = [ r['msg_id'] for r in recs ]
209 244 self.assertEquals(set(even), set(found))
210 245 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
211 246 found = [ r['msg_id'] for r in recs ]
212 247 self.assertEquals(set(odd), set(found))
213 248 
# hub_history must be ordered by submission time, oldest first, and a
# freshly completed task must show up at the tail.
214 249 def test_hub_history(self):
215 250 hist = self.client.hub_history()
216 251 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
217 252 recdict = {}
218 253 for rec in recs:
219 254 recdict[rec['msg_id']] = rec
220 255 
221 256 latest = datetime(1984,1,1)
222 257 for msg_id in hist:
223 258 rec = recdict[msg_id]
224 259 newt = rec['submitted']
225 260 self.assertTrue(newt >= latest)
226 261 latest = newt
227 262 ar = self.client[-1].apply_async(lambda : 1)
228 263 ar.get()
229 264 time.sleep(0.25)
230 265 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
231 266 
# resubmit re-executes f, so two calls to random.random() should differ.
232 267 def test_resubmit(self):
233 268 def f():
234 269 import random
235 270 return random.random()
236 271 v = self.client.load_balanced_view()
237 272 ar = v.apply_async(f)
238 273 r1 = ar.get(1)
239 274 # give the Hub a chance to notice:
240 275 time.sleep(0.5)
241 276 ahr = self.client.resubmit(ar.msg_ids)
242 277 r2 = ahr.get(1)
243 278 self.assertFalse(r1 == r2)
244 279 
245 280 def test_resubmit_inflight(self):
246 281 """ensure ValueError on resubmit of inflight task"""
247 282 v = self.client.load_balanced_view()
248 283 ar = v.apply_async(time.sleep,1)
249 284 # give the message a chance to arrive
250 285 time.sleep(0.2)
251 286 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
252 287 ar.get(2)
253 288 
254 289 def test_resubmit_badkey(self):
255 290 """ensure KeyError on resubmit of nonexistant task"""
256 291 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
257 292 
258 293 def test_purge_results(self):
259 294 # ensure there are some tasks
260 295 for i in range(5):
261 296 self.client[:].apply_sync(lambda : 1)
262 297 # Wait for the Hub to realise the result is done:
263 298 # This prevents a race condition, where we
264 299 # might purge a result the Hub still thinks is pending.
265 300 time.sleep(0.1)
266 301 rc2 = clientmod.Client(profile='iptest')
267 302 hist = self.client.hub_history()
268 303 ahr = rc2.get_result([hist[-1]])
269 304 ahr.wait(10)
270 305 self.client.purge_results(hist[-1])
271 306 newhist = self.client.hub_history()
272 307 self.assertEquals(len(newhist)+1,len(hist))
273 308 rc2.spin()
274 309 rc2.close()
275 310 
276 311 def test_purge_all_results(self):
277 312 self.client.purge_results('all')
278 313 hist = self.client.hub_history()
279 314 self.assertEquals(len(hist), 0)
280 315 
281 316
General Comments 0
You need to be logged in to leave comments. Log in now.