##// END OF EJS Templates
update client.get_result to match AsyncResult behavior
MinRK -
Show More
@@ -1,1824 +1,1831 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.kernel.zmq.session import Session, Message
52 52 from IPython.kernel.zmq import serialize
53 53
54 54 from .asyncresult import AsyncResult, AsyncHubResult
55 55 from .view import DirectView, LoadBalancedView
56 56
57 57 if sys.version_info[0] >= 3:
58 58 # xrange is used in a couple 'isinstance' tests in py2
59 59 # should be just 'range' in 3k
60 60 xrange = range
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Decorators for Client methods
64 64 #--------------------------------------------------------------------------
65 65
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Decorator for Client methods that must observe up-to-date results
    and registration state: spin() flushes any pending incoming messages
    before `f` runs.
    """
    self.spin()
    return f(self, *args, **kwargs)
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Classes
75 75 #--------------------------------------------------------------------------
76 76
77 77
class ExecuteReply(object):
    """wrapper for finished Execute results

    Provides item/attribute access to the reply metadata, and the IPython
    rich-display methods (`_repr_html_`, `_repr_png_`, ...) backed by the
    captured ``pyout`` display data.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    def _pyout_data(self):
        """Return the pyout display-data dict ({} when there was no pyout).

        Factored out of the repr methods, which previously each repeated
        the same ``metadata['pyout'] or {'data':{}}`` expression.
        """
        pyout = self.metadata['pyout'] or {'data':{}}
        return pyout['data']

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        # attribute access falls through to metadata keys
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        text_out = self._pyout_data().get('text/plain', '')
        if len(text_out) > 32:
            # truncate long output so the repr stays a short single line
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        """IPython pretty-printer hook: mimic an ``Out[engine:count]:`` prompt."""
        text_out = self._pyout_data().get('text/plain', '')

        if not text_out:
            return

        try:
            ip = get_ipython()
        except NameError:
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )

    def _repr_html_(self):
        return self._pyout_data().get("text/html")

    def _repr_latex_(self):
        return self._pyout_data().get("text/latex")

    def _repr_json_(self):
        return self._pyout_data().get("application/json")

    def _repr_javascript_(self):
        return self._pyout_data().get("application/javascript")

    def _repr_png_(self):
        return self._pyout_data().get("image/png")

    def _repr_jpeg_(self):
        return self._pyout_data().get("image/jpeg")

    def _repr_svg_(self):
        return self._pyout_data().get("image/svg+xml")
159 159
160 160
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full set of allowed keys, with their default values:
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
              'data': {},
              'outputs_ready' : False,
        }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` replaces py2-only iterkeys(): identical semantics,
        # O(1) instead of building an iterator, and works on Python 3.
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
214 214
215 215
216 216 class Client(HasTraits):
217 217 """A semi-synchronous client to the IPython ZMQ cluster
218 218
219 219 Parameters
220 220 ----------
221 221
222 222 url_file : str/unicode; path to ipcontroller-client.json
223 223 This JSON file should contain all the information needed to connect to a cluster,
224 224 and is likely the only argument needed.
225 225 Connection information for the Hub's registration. If a json connector
226 226 file is given, then likely no further configuration is necessary.
227 227 [Default: use profile]
228 228 profile : bytes
229 229 The name of the Cluster profile to be used to find connector information.
230 230 If run from an IPython application, the default profile will be the same
231 231 as the running application, otherwise it will be 'default'.
232 232 cluster_id : str
233 233 String id to added to runtime files, to prevent name collisions when using
234 234 multiple clusters with a single profile simultaneously.
235 235 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 236 Since this is text inserted into filenames, typical recommendations apply:
237 237 Simple character strings are ideal, and spaces are not recommended (but
238 238 should generally work)
239 239 context : zmq.Context
240 240 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 241 debug : bool
242 242 flag for lots of message printing for debug purposes
243 243 timeout : int/float
244 244 time (in seconds) to wait for connection replies from the Hub
245 245 [Default: 10]
246 246
247 247 #-------------- session related args ----------------
248 248
249 249 config : Config object
250 250 If specified, this will be relayed to the Session for configuration
251 251 username : str
252 252 set username for the session object
253 253
254 254 #-------------- ssh related args ----------------
255 255 # These are args for configuring the ssh tunnel to be used
256 256 # credentials are used to forward connections over ssh to the Controller
257 257 # Note that the ip given in `addr` needs to be relative to sshserver
258 258 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 259 # and set sshserver as the same machine the Controller is on. However,
260 260 # the only requirement is that sshserver is able to see the Controller
261 261 # (i.e. is within the same trusted network).
262 262
263 263 sshserver : str
264 264 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 265 If keyfile or password is specified, and this is not, it will default to
266 266 the ip given in addr.
267 267 sshkey : str; path to ssh private key file
268 268 This specifies a key to be used in ssh login, default None.
269 269 Regular default ssh keys will be used without specifying this argument.
270 270 password : str
271 271 Your ssh password to sshserver. Note that if this is left None,
272 272 you will be prompted for it if passwordless key based login is unavailable.
273 273 paramiko : bool
274 274 flag for whether to use paramiko instead of shell ssh for tunneling.
275 275 [default: True on win32, False else]
276 276
277 277
278 278 Attributes
279 279 ----------
280 280
281 281 ids : list of int engine IDs
282 282 requesting the ids attribute always synchronizes
283 283 the registration state. To request ids without synchronization,
284 284 use semi-private _ids attributes.
285 285
286 286 history : list of msg_ids
287 287 a list of msg_ids, keeping track of all the execution
288 288 messages you have submitted in order.
289 289
290 290 outstanding : set of msg_ids
291 291 a set of msg_ids that have been submitted, but whose
292 292 results have not yet been received.
293 293
294 294 results : dict
295 295 a dict of all our results, keyed by msg_id
296 296
297 297 block : bool
298 298 determines default behavior when block not specified
299 299 in execution methods
300 300
301 301 Methods
302 302 -------
303 303
304 304 spin
305 305 flushes incoming results and registration state changes
306 306 control methods spin, and requesting `ids` also ensures up to date
307 307
308 308 wait
309 309 wait on one or more msg_ids
310 310
311 311 execution methods
312 312 apply
313 313 legacy: execute, run
314 314
315 315 data movement
316 316 push, pull, scatter, gather
317 317
318 318 query methods
319 319 queue_status, get_result, purge, result_status
320 320
321 321 control methods
322 322 abort, shutdown
323 323
324 324 """
325 325
326 326
    # -- public state ----------------------------------------------------
    block = Bool(False)   # default blocking behavior for execution methods
    outstanding = Set()   # msg_ids submitted whose results have not arrived
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> Metadata
    history = List()      # all submitted msg_ids, in order
    debug = Bool(False)
    _spin_thread = Any()    # background spin thread, when one is running
    _stop_spinning = Any()  # threading.Event used to halt the spin thread

    profile=Unicode()
    def _profile_default(self):
        """Dynamic default for `profile`: the running IPython app's profile,
        falling back to u'default'."""
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'


    # -- private state ---------------------------------------------------
    _outstanding_dict = Instance('collections.defaultdict', (set,))  # engine_uuid -> set of outstanding msg_ids
    _ids = List()            # currently-registered engine ids
    _connected=Bool(False)
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()         # connection info loaded from the url_file
    _engines=Instance(util.ReverseDict, (), {})  # engine id <-> uuid mapping
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    _task_scheme=Unicode()   # task scheduler scheme; 'pure' gets special handling
    _closed = False
    _ignored_control_replies=Integer(0)  # count of control replies to drop silently
    _ignored_hub_replies=Integer(0)
368 368
    def __new__(self, *args, **kw):
        # don't raise on positional args: HasTraits.__new__ would reject
        # them, but Client.__init__ consumes positional arguments itself.
        return HasTraits.__new__(self, **kw)
372 372
    def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, cluster_id=None, **extra_args
            ):
        """Load connection info from `url_file` (or the profile's default
        connection file), optionally set up SSH tunneling, construct the
        Session, and connect to the Hub.  See the class docstring for
        parameter details.
        """
        # only pass profile when explicitly given, so the _profile_default
        # trait logic runs otherwise
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context
        self._stop_spinning = Event()

        # deprecated alias for url_file
        if 'url_or_file' in extra_args:
            url_file = extra_args['url_or_file']
            warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)

        if url_file and util.is_url(url_file):
            raise ValueError("single urls cannot be specified, url-files must be used.")

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)

        # derive the connection-file path from the profile dir when not given
        if self._cd is not None:
            if url_file is None:
                if not cluster_id:
                    client_json = 'ipcontroller-client.json'
                else:
                    client_json = 'ipcontroller-%s-client.json' % cluster_id
                url_file = pjoin(self._cd.security_dir, client_json)
        if url_file is None:
            raise ValueError(
                "I can't find enough information to connect to a hub!"
                " Please specify at least one of url_file or profile."
            )

        with open(url_file) as f:
            cfg = json.load(f)

        self._task_scheme = cfg['task_scheme']

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver

        location = cfg.setdefault('location', None)

        # resolve the advertised address relative to `location`
        proto,addr = cfg['interface'].split('://')
        addr = util.disambiguate_ip_address(addr, location)
        cfg['interface'] = "%s://%s" % (proto, addr)

        # turn interface,port into full urls:
        for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
            cfg[key] = cfg['interface'] + ':%i' % cfg[key]

        url = cfg['registration']

        if location is not None and addr == LOCALHOST:
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = addr
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        extra_args['packer'] = cfg['pack']
        extra_args['unpacker'] = cfg['unpack']
        extra_args['key'] = cast_bytes(cfg['exec_key'])

        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)

        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(cfg['registration'])

        self.session.debug = self.debug

        # dispatch tables for incoming notification / queue messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

        # last step: setup magics, if we are in IPython:

        try:
            ip = get_ipython()
        except NameError:
            return
        else:
            if 'px' not in ip.magics_manager.magics:
                # in IPython but we are the first Client.
                # activate a default view for parallel magics.
                self.activate()
497 497
    def __del__(self):
        """cleanup sockets, but _not_ context."""
        # the zmq Context may be the shared global instance, so only our
        # own sockets are closed here
        self.close()
501 501
502 502 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
503 503 if ipython_dir is None:
504 504 ipython_dir = get_ipython_dir()
505 505 if profile_dir is not None:
506 506 try:
507 507 self._cd = ProfileDir.find_profile_dir(profile_dir)
508 508 return
509 509 except ProfileDirError:
510 510 pass
511 511 elif profile is not None:
512 512 try:
513 513 self._cd = ProfileDir.find_profile_dir_by_name(
514 514 ipython_dir, profile)
515 515 return
516 516 except ProfileDirError:
517 517 pass
518 518 self._cd = None
519 519
520 520 def _update_engines(self, engines):
521 521 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
522 522 for k,v in engines.iteritems():
523 523 eid = int(k)
524 524 if eid not in self._engines:
525 525 self._ids.append(eid)
526 526 self._engines[eid] = v
527 527 self._ids = sorted(self._ids)
528 528 if sorted(self._engines.keys()) != range(len(self._engines)) and \
529 529 self._task_scheme == 'pure' and self._task_socket:
530 530 self._stop_scheduling_tasks()
531 531
532 532 def _stop_scheduling_tasks(self):
533 533 """Stop scheduling tasks because an engine has been unregistered
534 534 from a pure ZMQ scheduler.
535 535 """
536 536 self._task_socket.close()
537 537 self._task_socket = None
538 538 msg = "An engine has been unregistered, and we are using pure " +\
539 539 "ZMQ task scheduling. Task farming will be disabled."
540 540 if self.outstanding:
541 541 msg += " If you were running tasks when this happened, " +\
542 542 "some `outstanding` msg_ids may never resolve."
543 543 warnings.warn(msg, RuntimeWarning)
544 544
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Raises NoEnginesRegistered when no engines are connected,
        IndexError for an unknown engine id, and TypeError for any
        other invalid target specification.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            # 'all' is the only valid string target
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            # negative ints index from the end, like list indexing
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        # uuids (as bytes) are the socket routing identities for the engines
        return [cast_bytes(self._engines[t]) for t in targets], list(targets)
577 577
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits
        up to `timeout` seconds for the reply, then connects the mux, task,
        notification, control and iopub sockets using the urls in our config.
        Raises TimeoutError on no reply, and a generic Exception when the
        Hub rejects the connection.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # every socket tunnels through ssh when self._ssh is set
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        cfg = self._config
        if content['status'] == 'ok':
            self._mux_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._mux_socket, cfg['mux'])

            self._task_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._task_socket, cfg['task'])

            # SUB sockets subscribe to everything (empty prefix)
            self._notification_socket = self._context.socket(zmq.SUB)
            self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._notification_socket, cfg['notification'])

            self._control_socket = self._context.socket(zmq.DEALER)
            connect_socket(self._control_socket, cfg['control'])

            self._iopub_socket = self._context.socket(zmq.SUB)
            self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
            connect_socket(self._iopub_socket, cfg['iopub'])

            # seed our engine table from the Hub's reply
            self._update_engines(dict(content['engines']))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
630 630
631 631 #--------------------------------------------------------------------------
632 632 # handlers and callbacks for incoming messages
633 633 #--------------------------------------------------------------------------
634 634
635 635 def _unwrap_exception(self, content):
636 636 """unwrap exception, and remap engine_id to int."""
637 637 e = error.unwrap_exception(content)
638 638 # print e.traceback
639 639 if e.engine_info:
640 640 e_uuid = e.engine_info['engine_uuid']
641 641 eid = self._engines[e_uuid]
642 642 e.engine_info['engine_id'] = eid
643 643 return e
644 644
645 645 def _extract_metadata(self, msg):
646 646 header = msg['header']
647 647 parent = msg['parent_header']
648 648 msg_meta = msg['metadata']
649 649 content = msg['content']
650 650 md = {'msg_id' : parent['msg_id'],
651 651 'received' : datetime.now(),
652 652 'engine_uuid' : msg_meta.get('engine', None),
653 653 'follow' : msg_meta.get('follow', []),
654 654 'after' : msg_meta.get('after', []),
655 655 'status' : content['status'],
656 656 }
657 657
658 658 if md['engine_uuid'] is not None:
659 659 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
660 660
661 661 if 'date' in parent:
662 662 md['submitted'] = parent['date']
663 663 if 'started' in msg_meta:
664 664 md['started'] = msg_meta['started']
665 665 if 'date' in header:
666 666 md['completed'] = header['date']
667 667 return md
668 668
669 669 def _register_engine(self, msg):
670 670 """Register a new engine, and update our connection info."""
671 671 content = msg['content']
672 672 eid = content['id']
673 673 d = {eid : content['uuid']}
674 674 self._update_engines(d)
675 675
676 676 def _unregister_engine(self, msg):
677 677 """Unregister an engine that has died."""
678 678 content = msg['content']
679 679 eid = int(content['id'])
680 680 if eid in self._ids:
681 681 self._ids.remove(eid)
682 682 uuid = self._engines.pop(eid)
683 683
684 684 self._handle_stranded_msgs(eid, uuid)
685 685
686 686 if self._task_socket and self._task_scheme == 'pure':
687 687 self._stop_scheduling_tasks()
688 688
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result; nothing to do
                continue
            try:
                # raise-then-catch so error.wrap_exception can capture a
                # real traceback for the EngineError
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message, so the normal apply_reply path marks
            # the task as failed:
            msg = self.session.msg('apply_reply', content=content)
            msg['parent_header']['msg_id'] = msg_id
            msg['metadata']['engine'] = uuid
            self._handle_apply_reply(msg)
712 712
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # a reply for something we no longer (or never) considered pending
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)

        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(msg))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the msg from the per-engine outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = ExecuteReply(msg_id, content, md)
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            # store the unwrapped exception; it is raised on access
            self.results[msg_id] = self._unwrap_exception(content)
752 752
753 753 def _handle_apply_reply(self, msg):
754 754 """Save the reply to an apply_request into our results."""
755 755 parent = msg['parent_header']
756 756 msg_id = parent['msg_id']
757 757 if msg_id not in self.outstanding:
758 758 if msg_id in self.history:
759 759 print ("got stale result: %s"%msg_id)
760 760 print self.results[msg_id]
761 761 print msg
762 762 else:
763 763 print ("got unknown result: %s"%msg_id)
764 764 else:
765 765 self.outstanding.remove(msg_id)
766 766 content = msg['content']
767 767 header = msg['header']
768 768
769 769 # construct metadata:
770 770 md = self.metadata[msg_id]
771 771 md.update(self._extract_metadata(msg))
772 772 # is this redundant?
773 773 self.metadata[msg_id] = md
774 774
775 775 e_outstanding = self._outstanding_dict[md['engine_uuid']]
776 776 if msg_id in e_outstanding:
777 777 e_outstanding.remove(msg_id)
778 778
779 779 # construct result:
780 780 if content['status'] == 'ok':
781 781 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
782 782 elif content['status'] == 'aborted':
783 783 self.results[msg_id] = error.TaskAborted(msg_id)
784 784 elif content['status'] == 'resubmitted':
785 785 # TODO: handle resubmission
786 786 pass
787 787 else:
788 788 self.results[msg_id] = self._unwrap_exception(content)
789 789
790 790 def _flush_notifications(self):
791 791 """Flush notifications of engine registrations waiting
792 792 in ZMQ queue."""
793 793 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
794 794 while msg is not None:
795 795 if self.debug:
796 796 pprint(msg)
797 797 msg_type = msg['header']['msg_type']
798 798 handler = self._notification_handlers.get(msg_type, None)
799 799 if handler is None:
800 800 raise Exception("Unhandled message type: %s" % msg_type)
801 801 else:
802 802 handler(msg)
803 803 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
804 804
805 805 def _flush_results(self, sock):
806 806 """Flush task or queue results waiting in ZMQ queue."""
807 807 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
808 808 while msg is not None:
809 809 if self.debug:
810 810 pprint(msg)
811 811 msg_type = msg['header']['msg_type']
812 812 handler = self._queue_handlers.get(msg_type, None)
813 813 if handler is None:
814 814 raise Exception("Unhandled message type: %s" % msg_type)
815 815 else:
816 816 handler(msg)
817 817 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
818 818
819 819 def _flush_control(self, sock):
820 820 """Flush replies from the control channel waiting
821 821 in the ZMQ queue.
822 822
823 823 Currently: ignore them."""
824 824 if self._ignored_control_replies <= 0:
825 825 return
826 826 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
827 827 while msg is not None:
828 828 self._ignored_control_replies -= 1
829 829 if self.debug:
830 830 pprint(msg)
831 831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832 832
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv: we *know* this many replies are en route,
        # so wait for each one rather than polling
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
838 838
839 839 def _flush_ignored_hub_replies(self):
840 840 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
841 841 while msg is not None:
842 842 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843 843
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.

        Each message's payload is folded into self.metadata[msg_id],
        keyed by the parent request that produced the output.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            # ignore IOPub messages with no parent.
            # Caused by print statements or warnings from before the first execution.
            if not parent:
                idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
                continue
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['header']['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append stdout/stderr text to whatever was already captured
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            elif msg_type == 'display_data':
                md['outputs'].append(content)
            elif msg_type == 'pyout':
                md['pyout'] = content
            elif msg_type == 'data_message':
                # deserialize data published via the data-publication API
                data, remainder = serialize.unserialize_object(msg['buffers'])
                md['data'].update(data)
            elif msg_type == 'status':
                # idle message comes after all outputs
                if content['execution_state'] == 'idle':
                    md['outputs_ready'] = True
            else:
                # unhandled msg_type (status, etc.)
                pass

            # redundant?
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
893 893
894 894 #--------------------------------------------------------------------------
895 895 # len, getitem
896 896 #--------------------------------------------------------------------------
897 897
898 898 def __len__(self):
899 899 """len(client) returns # of engines."""
900 900 return len(self.ids)
901 901
902 902 def __getitem__(self, key):
903 903 """index access returns DirectView multiplexer objects
904 904
905 905 Must be int, slice, or list/tuple/xrange of ints"""
906 906 if not isinstance(key, (int, slice, tuple, list, xrange)):
907 907 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
908 908 else:
909 909 return self.direct_view(key)
910 910
911 911 #--------------------------------------------------------------------------
912 912 # Begin public methods
913 913 #--------------------------------------------------------------------------
914 914
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # process any pending (un)registration notifications first
        self._flush_notifications()
        # always copy: callers must not be able to mutate our registry
        return list(self._ids)
921 921
922 922 def activate(self, targets='all', suffix=''):
923 923 """Create a DirectView and register it with IPython magics
924 924
925 925 Defines the magics `%px, %autopx, %pxresult, %%px`
926 926
927 927 Parameters
928 928 ----------
929 929
930 930 targets: int, list of ints, or 'all'
931 931 The engines on which the view's magics will run
932 932 suffix: str [default: '']
933 933 The suffix, if any, for the magics. This allows you to have
934 934 multiple views associated with parallel magics at the same time.
935 935
936 936 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
937 937 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
938 938 on engine 0.
939 939 """
940 940 view = self.direct_view(targets)
941 941 view.block = True
942 942 view.activate(suffix)
943 943 return view
944 944
945 945 def close(self):
946 946 if self._closed:
947 947 return
948 948 self.stop_spin_thread()
949 949 snames = filter(lambda n: n.endswith('socket'), dir(self))
950 950 for socket in map(lambda name: getattr(self, name), snames):
951 951 if isinstance(socket, zmq.Socket) and not socket.closed:
952 952 socket.close()
953 953 self._closed = True
954 954
955 955 def _spin_every(self, interval=1):
956 956 """target func for use in spin_thread"""
957 957 while True:
958 958 if self._stop_spinning.is_set():
959 959 return
960 960 time.sleep(interval)
961 961 self.spin()
962 962
963 963 def spin_thread(self, interval=1):
964 964 """call Client.spin() in a background thread on some regular interval
965 965
966 966 This helps ensure that messages don't pile up too much in the zmq queue
967 967 while you are working on other things, or just leaving an idle terminal.
968 968
969 969 It also helps limit potential padding of the `received` timestamp
970 970 on AsyncResult objects, used for timings.
971 971
972 972 Parameters
973 973 ----------
974 974
975 975 interval : float, optional
976 976 The interval on which to spin the client in the background thread
977 977 (simply passed to time.sleep).
978 978
979 979 Notes
980 980 -----
981 981
982 982 For precision timing, you may want to use this method to put a bound
983 983 on the jitter (in seconds) in `received` timestamps used
984 984 in AsyncResult.wall_time.
985 985
986 986 """
987 987 if self._spin_thread is not None:
988 988 self.stop_spin_thread()
989 989 self._stop_spinning.clear()
990 990 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
991 991 self._spin_thread.daemon = True
992 992 self._spin_thread.start()
993 993
994 994 def stop_spin_thread(self):
995 995 """stop background spin_thread, if any"""
996 996 if self._spin_thread is not None:
997 997 self._stop_spinning.set()
998 998 self._spin_thread.join()
999 999 self._spin_thread = None
1000 1000
    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        # each socket may be None if this client was connected without it
        if self._notification_socket:
            self._flush_notifications()
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()
1017 1017
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            # wait on everything still in flight
            theids = self.outstanding
        else:
            # normalize a single job to a list, then expand to a set of msg_ids
            if isinstance(jobs, (int, basestring, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        # fast path: everything already finished
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        # poll at 1ms granularity until done or timed out
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
1062 1062
1063 1063 #--------------------------------------------------------------------------
1064 1064 # Control methods
1065 1065 #--------------------------------------------------------------------------
1066 1066
    @spin_first
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        # NOTE(review): this local shadows the module-level `error` import
        error = False
        if block:
            self._flush_ignored_control()
            # one reply expected per targeted engine
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            # not waiting: remember how many replies to silently drain later
            self._ignored_control_replies += len(targets)
        if error:
            raise error
1087 1087
1088 1088
    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted

            If unspecified/None: abort all outstanding jobs.

        """
        block = self.block if block is None else block
        jobs = jobs if jobs is not None else list(self.outstanding)
        targets = self._build_targets(targets)[0]

        # flatten jobs (strings and AsyncResults) into a plain list of msg_ids
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        # NOTE(review): this local shadows the module-level `error` import
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            # not waiting: remember how many replies to silently drain later
            self._ignored_control_replies += len(targets)
        if error:
            raise error
1137 1137
    @spin_first
    def shutdown(self, targets='all', restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        Parameters
        ----------

        targets: list of ints or 'all' [default: all]
            Which engines to shutdown.
        hub: bool [default: False]
            Whether to include the Hub.  hub=True implies targets='all'.
        block: bool [default: self.block]
            Whether to wait for clean shutdown replies or not.
        restart: bool [default: False]
            NOT IMPLEMENTED
            whether to restart engines after shutting them down.
        """
        from IPython.parallel.error import NoEnginesRegistered
        if restart:
            raise NotImplementedError("Engine restart is not yet implemented")

        block = self.block if block is None else block
        if hub:
            # shutting down the hub takes every engine with it
            targets = 'all'
        try:
            targets = self._build_targets(targets)[0]
        except NoEnginesRegistered:
            # nothing to shut down on the engine side
            targets = []
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        # always wait for engine replies when the hub is going down too
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            # brief pause to let engine shutdowns land before the hub goes away
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
1192 1192
1193 1193 #--------------------------------------------------------------------------
1194 1194 # Execution related methods
1195 1195 #--------------------------------------------------------------------------
1196 1196
1197 1197 def _maybe_raise(self, result):
1198 1198 """wrapper for maybe raising an exception if apply failed."""
1199 1199 if isinstance(result, error.RemoteError):
1200 1200 raise result
1201 1201
1202 1202 return result
1203 1203
    def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.

        Returns the sent message dict; its msg_id is recorded in
        self.outstanding / self.history and metadata['submitted'] is stamped.
        """

        if self._closed:
            raise RuntimeError("Client cannot be used after its sockets have been closed")

        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        metadata = metadata if metadata is not None else {}

        # validate arguments
        if not callable(f) and not isinstance(f, Reference):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(metadata, dict):
            raise TypeError("metadata must be dict, not %s"%type(metadata))

        # serialize f/args/kwargs into message buffers
        bufs = serialize.pack_apply_message(f, args, kwargs,
            buffer_threshold=self.session.buffer_threshold,
            item_threshold=self.session.item_threshold,
        )

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            metadata=metadata, track=track)

        msg_id = msg['header']['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1250 1250
    def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
        """construct and send an execute request via a socket.

        Returns the sent message dict; bookkeeping mirrors
        send_apply_request (outstanding, history, submitted timestamp).
        """

        if self._closed:
            raise RuntimeError("Client cannot be used after its sockets have been closed")

        # defaults:
        metadata = metadata if metadata is not None else {}

        # validate arguments
        if not isinstance(code, basestring):
            raise TypeError("code must be text, not %s" % type(code))
        if not isinstance(metadata, dict):
            raise TypeError("metadata must be dict, not %s" % type(metadata))

        content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})


        msg = self.session.send(socket, "execute_request", content=content, ident=ident,
                            metadata=metadata)

        msg_id = msg['header']['msg_id']
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1287 1287
1288 1288 #--------------------------------------------------------------------------
1289 1289 # construct a View object
1290 1290 #--------------------------------------------------------------------------
1291 1291
    def load_balanced_view(self, targets=None):
        """construct a LoadBalancedView object.

        If no arguments are specified, create a LoadBalancedView
        using all engines.

        Parameters
        ----------

        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        # 'all' and None both mean: let the scheduler pick among all engines
        if targets == 'all':
            targets = None
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1309 1309
1310 1310 def direct_view(self, targets='all'):
1311 1311 """construct a DirectView object.
1312 1312
1313 1313 If no targets are specified, create a DirectView using all engines.
1314 1314
1315 1315 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1316 1316 evaluate the target engines at each execution, whereas rc[:] will connect to
1317 1317 all *current* engines, and that list will not change.
1318 1318
1319 1319 That is, 'all' will always use all engines, whereas rc[:] will not use
1320 1320 engines added after the DirectView is constructed.
1321 1321
1322 1322 Parameters
1323 1323 ----------
1324 1324
1325 1325 targets: list,slice,int,etc. [default: use all engines]
1326 1326 The engines to use for the View
1327 1327 """
1328 1328 single = isinstance(targets, int)
1329 1329 # allow 'all' to be lazily evaluated at each execution
1330 1330 if targets != 'all':
1331 1331 targets = self._build_targets(targets)[1]
1332 1332 if single:
1333 1333 targets = targets[0]
1334 1334 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1335 1335
1336 1336 #--------------------------------------------------------------------------
1337 1337 # Query methods
1338 1338 #--------------------------------------------------------------------------
1339 1339
    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

        If the client already has the results, no request to the Hub will be made.

        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.

        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the metadata as well.

        Examples
        --------
        ::

            In [10]: r = client.apply()

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncResult
            A single AsyncResult object will always be returned.

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default: most recent submission
            indices_or_msg_ids = -1

        # remember whether the caller passed a single id, to match
        # AsyncResult behavior (scalar in -> scalar msg_id, not a list)
        single_result = False
        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]
            single_result = True

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # int is an index into this client's submission history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        # given single msg_id initially, get_result should get the result itself,
        # not a length-one list
        if single_result:
            theids = theids[0]

        if remote_ids:
            # at least one result must be fetched from the Hub
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
1405 1412
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default: most recent submission
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                # int is an index into this client's submission history
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        # block until the hub's reply is available, then read it
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        # the hub assigns fresh msg_ids to resubmitted tasks
        mapping = content['resubmitted']
        new_ids = [ mapping[msg_id] for msg_id in theids ]

        ar = AsyncHubResult(self, msg_ids=new_ids)

        if block:
            ar.wait()

        return ar
1463 1470
1464 1471 @spin_first
1465 1472 def result_status(self, msg_ids, status_only=True):
1466 1473 """Check on the status of the result(s) of the apply request with `msg_ids`.
1467 1474
1468 1475 If status_only is False, then the actual results will be retrieved, else
1469 1476 only the status of the results will be checked.
1470 1477
1471 1478 Parameters
1472 1479 ----------
1473 1480
1474 1481 msg_ids : list of msg_ids
1475 1482 if int:
1476 1483 Passed as index to self.history for convenience.
1477 1484 status_only : bool (default: True)
1478 1485 if False:
1479 1486 Retrieve the actual results of completed tasks.
1480 1487
1481 1488 Returns
1482 1489 -------
1483 1490
1484 1491 results : dict
1485 1492 There will always be the keys 'pending' and 'completed', which will
1486 1493 be lists of msg_ids that are incomplete or complete. If `status_only`
1487 1494 is False, then completed results will be keyed by their `msg_id`.
1488 1495 """
1489 1496 if not isinstance(msg_ids, (list,tuple)):
1490 1497 msg_ids = [msg_ids]
1491 1498
1492 1499 theids = []
1493 1500 for msg_id in msg_ids:
1494 1501 if isinstance(msg_id, int):
1495 1502 msg_id = self.history[msg_id]
1496 1503 if not isinstance(msg_id, basestring):
1497 1504 raise TypeError("msg_ids must be str, not %r"%msg_id)
1498 1505 theids.append(msg_id)
1499 1506
1500 1507 completed = []
1501 1508 local_results = {}
1502 1509
1503 1510 # comment this block out to temporarily disable local shortcut:
1504 1511 for msg_id in theids:
1505 1512 if msg_id in self.results:
1506 1513 completed.append(msg_id)
1507 1514 local_results[msg_id] = self.results[msg_id]
1508 1515 theids.remove(msg_id)
1509 1516
1510 1517 if theids: # some not locally cached
1511 1518 content = dict(msg_ids=theids, status_only=status_only)
1512 1519 msg = self.session.send(self._query_socket, "result_request", content=content)
1513 1520 zmq.select([self._query_socket], [], [])
1514 1521 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1515 1522 if self.debug:
1516 1523 pprint(msg)
1517 1524 content = msg['content']
1518 1525 if content['status'] != 'ok':
1519 1526 raise self._unwrap_exception(content)
1520 1527 buffers = msg['buffers']
1521 1528 else:
1522 1529 content = dict(completed=[],pending=[])
1523 1530
1524 1531 content['completed'].extend(completed)
1525 1532
1526 1533 if status_only:
1527 1534 return content
1528 1535
1529 1536 failures = []
1530 1537 # load cached results into result:
1531 1538 content.update(local_results)
1532 1539
1533 1540 # update cache with results:
1534 1541 for msg_id in sorted(theids):
1535 1542 if msg_id in content['completed']:
1536 1543 rec = content[msg_id]
1537 1544 parent = rec['header']
1538 1545 header = rec['result_header']
1539 1546 rcontent = rec['result_content']
1540 1547 iodict = rec['io']
1541 1548 if isinstance(rcontent, str):
1542 1549 rcontent = self.session.unpack(rcontent)
1543 1550
1544 1551 md = self.metadata[msg_id]
1545 1552 md_msg = dict(
1546 1553 content=rcontent,
1547 1554 parent_header=parent,
1548 1555 header=header,
1549 1556 metadata=rec['result_metadata'],
1550 1557 )
1551 1558 md.update(self._extract_metadata(md_msg))
1552 1559 if rec.get('received'):
1553 1560 md['received'] = rec['received']
1554 1561 md.update(iodict)
1555 1562
1556 1563 if rcontent['status'] == 'ok':
1557 1564 if header['msg_type'] == 'apply_reply':
1558 1565 res,buffers = serialize.unserialize_object(buffers)
1559 1566 elif header['msg_type'] == 'execute_reply':
1560 1567 res = ExecuteReply(msg_id, rcontent, md)
1561 1568 else:
1562 1569 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1563 1570 else:
1564 1571 res = self._unwrap_exception(rcontent)
1565 1572 failures.append(res)
1566 1573
1567 1574 self.results[msg_id] = res
1568 1575 content[msg_id] = res
1569 1576
1570 1577 if len(theids) == 1 and failures:
1571 1578 raise failures[0]
1572 1579
1573 1580 error.collect_exceptions(failures, "result_status")
1574 1581 return content
1575 1582
    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------

        targets : int/str/list of ints/strs
                the engines whose states are to be queried.
                default : all
        verbose : bool
                Whether to return lengths only, or lists of ids for each element
        """
        if targets == 'all':
            # allow 'all' to be evaluated on the engine
            engine_ids = None
        else:
            engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        # rekey: convert string engine ids in the reply back into ints
        content = rekey(content)
        if isinstance(targets, int):
            # single int target: return just that engine's record
            return content[targets]
        else:
            return content
1608 1615
1609 1616 def _build_msgids_from_target(self, targets=None):
1610 1617 """Build a list of msg_ids from the list of engine targets"""
1611 1618 if not targets: # needed as _build_targets otherwise uses all engines
1612 1619 return []
1613 1620 target_ids = self._build_targets(targets)[0]
1614 1621 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1615 1622
1616 1623 def _build_msgids_from_jobs(self, jobs=None):
1617 1624 """Build a list of msg_ids from "jobs" """
1618 1625 if not jobs:
1619 1626 return []
1620 1627 msg_ids = []
1621 1628 if isinstance(jobs, (basestring,AsyncResult)):
1622 1629 jobs = [jobs]
1623 1630 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1624 1631 if bad_ids:
1625 1632 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1626 1633 for j in jobs:
1627 1634 if isinstance(j, AsyncResult):
1628 1635 msg_ids.extend(j.msg_ids)
1629 1636 else:
1630 1637 msg_ids.append(j)
1631 1638 return msg_ids
1632 1639
    def purge_local_results(self, jobs=[], targets=[]):
        """Clears the client caches of results and frees such memory.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_local_results('all')` to scrub everything from the Clients's db.

        The client must have no outstanding tasks before purging the caches.
        Raises `AssertionError` if there are still outstanding tasks.

        After this call all `AsyncResults` are invalid and should be discarded.

        If you must "reget" the results, you can still do so by using
        `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
        redownload the results from the hub if they are still available
        (i.e `client.purge_hub_results(...)` has not been called.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
                the msg_ids whose results should be purged.
        targets : int/str/list of ints/strs
                The targets, by int_id, whose entire results are to be purged.

        default : None
        """
        # outstanding tasks would be orphaned if their records were dropped
        assert not self.outstanding, "Can't purge a client with outstanding tasks!"

        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")

        if jobs == 'all':
            self.results.clear()
            self.metadata.clear()
            return
        else:
            msg_ids = []
            msg_ids.extend(self._build_msgids_from_target(targets))
            msg_ids.extend(self._build_msgids_from_jobs(jobs))
            # py2 map: used purely for its side effect of popping each key
            map(self.results.pop, msg_ids)
            map(self.metadata.pop, msg_ids)
1676 1683
1677 1684
    @spin_first
    def purge_hub_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_results('all')` to scrub everything from the Hub's db.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
                the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
                The targets, by int_id, whose entire history is to be purged.

        default : None
        """
        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        if jobs == 'all':
            # the hub understands the literal 'all'
            msg_ids = jobs
        else:
            msg_ids = self._build_msgids_from_jobs(jobs)

        content = dict(engine_ids=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
1716 1723
    def purge_results(self, jobs=[], targets=[]):
        """Clears the cached results from both the hub and the local client

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_results('all')` to scrub every cached result from both the Hub's and
        the Client's db.

        Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
        the same arguments.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
                the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
                The targets, by int_id, whose entire history is to be purged.

        default : None
        """
        # purge locally first: purge_local_results asserts no outstanding tasks
        self.purge_local_results(jobs=jobs, targets=targets)
        self.purge_hub_results(jobs=jobs, targets=targets)
1741 1748
    def purge_everything(self):
        """Clears all content from previous Tasks from both the hub and the local client

        In addition to calling `purge_results("all")` it also deletes the history and
        other bookkeeping lists.
        """
        self.purge_results("all")
        # drop local submission history and the session's digest replay cache
        self.history = []
        self.session.digest_history.clear()
1751 1758
1752 1759 @spin_first
1753 1760 def hub_history(self):
1754 1761 """Get the Hub's history
1755 1762
1756 1763 Just like the Client, the Hub has a history, which is a list of msg_ids.
1757 1764 This will contain the history of all clients, and, depending on configuration,
1758 1765 may contain history across multiple cluster sessions.
1759 1766
1760 1767 Any msg_id returned here is a valid argument to `get_result`.
1761 1768
1762 1769 Returns
1763 1770 -------
1764 1771
1765 1772 msg_ids : list of strs
1766 1773 list of all msg_ids, ordered by task submission time.
1767 1774 """
1768 1775
1769 1776 self.session.send(self._query_socket, "history_request", content={})
1770 1777 idents, msg = self.session.recv(self._query_socket, 0)
1771 1778
1772 1779 if self.debug:
1773 1780 pprint(msg)
1774 1781 content = msg['content']
1775 1782 if content['status'] != 'ok':
1776 1783 raise self._unwrap_exception(content)
1777 1784 else:
1778 1785 return content['history']
1779 1786
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned. The default is to fetch everything but buffers.
            'msg_id' will *always* be included.
        """
        # allow a single key name as shorthand for a one-element list
        if isinstance(keys, basestring):
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # buffers for all records arrive as one flat list; the per-record
        # lengths tell us how to slice them back apart
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records
1823 1830
1824 1831 __all__ = [ 'Client' ]
@@ -1,517 +1,518 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython import parallel
28 28 from IPython.parallel.client import client as clientmod
29 29 from IPython.parallel import error
30 30 from IPython.parallel import AsyncResult, AsyncHubResult
31 31 from IPython.parallel import LoadBalancedView, DirectView
32 32
33 33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34 34
def setup():
    """Module-level nose setup: ensure four engines are running."""
    add_engines(4, total=True)
37 37
class TestClient(ClusterTestCase):
    """Tests for IPython.parallel.client.client.Client against a live cluster."""

    def test_ids(self):
        """client.ids grows when engines are added"""
        n = len(self.client.ids)
        self.add_engines(2)
        self.assertEqual(len(self.client.ids), n+2)

    def test_view_indexing(self):
        """test index access for views"""
        self.minimum_engines(4)
        targets = self.client._build_targets('all')[-1]
        v = self.client[:]
        self.assertEqual(v.targets, targets)
        t = self.client.ids[2]
        v = self.client[t]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, t)
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, t)
        v = self.client[::2]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, targets[::2])
        v = self.client[1::3]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, targets[:-3])
        v = self.client[-1]
        self.assertTrue(isinstance(v, DirectView))
        self.assertEqual(v.targets, targets[-1])
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_lbview_targets(self):
        """test load_balanced_view targets"""
        v = self.client.load_balanced_view()
        self.assertEqual(v.targets, None)
        v = self.client.load_balanced_view(-1)
        self.assertEqual(v.targets, [self.client.ids[-1]])
        v = self.client.load_balanced_view('all')
        self.assertEqual(v.targets, None)

    def test_dview_targets(self):
        """test direct_view targets"""
        v = self.client.direct_view()
        self.assertEqual(v.targets, 'all')
        v = self.client.direct_view('all')
        self.assertEqual(v.targets, 'all')
        v = self.client.direct_view(-1)
        self.assertEqual(v.targets, self.client.ids[-1])

    def test_lazy_all_targets(self):
        """test lazy evaluation of rc.direct_view('all')"""
        v = self.client.direct_view()
        self.assertEqual(v.targets, 'all')

        def double(x):
            return x*2
        seq = range(100)
        ref = [ double(x) for x in seq ]

        # add some engines, which should be used
        self.add_engines(1)
        n1 = len(self.client.ids)

        # simple apply
        r = v.apply_sync(lambda : 1)
        self.assertEqual(r, [1] * n1)

        # map goes through remotefunction
        r = v.map_sync(double, seq)
        self.assertEqual(r, ref)

        # add a couple more engines, and try again
        self.add_engines(2)
        n2 = len(self.client.ids)
        self.assertNotEqual(n2, n1)

        # apply
        r = v.apply_sync(lambda : 1)
        self.assertEqual(r, [1] * n2)

        # map
        r = v.map_sync(double, seq)
        self.assertEqual(r, ref)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        idents,targets = build(None)
        self.assertEqual(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        self.minimum_engines(2)
        v = self.client[:]
        v.block=True
        v.push(dict(a=5))
        v.pull('a')
        id0 = self.client.ids[-1]
        # clearing one engine removes its namespace, leaving the others intact
        self.client.clear(targets=id0, block=True)
        a = self.client[:-1].get('a')
        self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
        self.client.clear(block=True)
        for i in self.client.ids:
            self.assertRaisesRemote(NameError, self.client[i].get, 'a')

    def test_get_result(self):
        """test getting results from the Hub."""
        c = clientmod.Client(profile='iptest')
        t = c.ids[-1]
        ar = c[t].apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        # get_result takes a single msg_id, matching AsyncResult behavior
        ahr = self.client.get_result(ar.msg_ids[0])
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEqual(ahr.get(), ar.get())
        # second fetch finds the result locally, so no Hub round-trip needed
        ar2 = self.client.get_result(ar.msg_ids[0])
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()

    def test_get_execute_result(self):
        """test getting execute results from the Hub."""
        c = clientmod.Client(profile='iptest')
        t = c.ids[-1]
        ar = c[t].execute("import time; time.sleep(1)", silent=False)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = self.client.get_result(ar.msg_ids[0])
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEqual(ahr.get().pyout, ar.get().pyout)
        ar2 = self.client.get_result(ar.msg_ids[0])
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()

    def test_ids_list(self):
        """test client.ids"""
        ids = self.client.ids
        # .ids is a copy of the internal list, not the list itself
        self.assertEqual(ids, self.client._ids)
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEqual(ids, self.client._ids)

    def test_queue_status(self):
        """queue_status returns per-engine dicts plus an 'unassigned' count"""
        ids = self.client.ids
        id0 = ids[0]
        qs = self.client.queue_status(targets=id0)
        self.assertTrue(isinstance(qs, dict))
        self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
        allqs = self.client.queue_status()
        self.assertTrue(isinstance(allqs, dict))
        intkeys = list(allqs.keys())
        intkeys.remove('unassigned')
        self.assertEqual(sorted(intkeys), sorted(self.client.ids))
        unassigned = allqs.pop('unassigned')
        for eid,qs in allqs.items():
            self.assertTrue(isinstance(qs, dict))
            self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])

    def test_shutdown(self):
        """shutting down an engine removes it from client.ids"""
        ids = self.client.ids
        id0 = ids[0]
        self.client.shutdown(id0, block=True)
        # poll until the Hub notices the engine is gone
        while id0 in self.client.ids:
            time.sleep(0.1)
            self.client.spin()

        self.assertRaises(IndexError, lambda : self.client[id0])

    def test_result_status(self):
        pass
        # to be written

    def test_db_query_dt(self):
        """test db query by date"""
        hist = self.client.hub_history()
        middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
        tic = middle['submitted']
        before = self.client.db_query({'submitted' : {'$lt' : tic}})
        after = self.client.db_query({'submitted' : {'$gte' : tic}})
        self.assertEqual(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.client.db_query({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_db_query_keys(self):
        """test extracting subset of record keys"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_db_query_default_keys(self):
        """default db_query excludes buffers"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        for rec in found:
            keys = set(rec.keys())
            self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
            self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)

    def test_db_query_msg_id(self):
        """ensure msg_id is always in db queries"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_db_query_get_result(self):
        """pop in db_query shouldn't pop from result itself"""
        self.client[:].apply_sync(lambda : 1)
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        rc2 = clientmod.Client(profile='iptest')
        # If this bug is not fixed, this call will hang:
        ar = rc2.get_result(self.client.history[-1])
        ar.wait(2)
        self.assertTrue(ar.ready())
        ar.get()
        rc2.close()

    def test_db_query_in(self):
        """test db query with '$in','$nin' operators"""
        hist = self.client.hub_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(even), set(found))
        recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEqual(set(odd), set(found))

    def test_hub_history(self):
        """hub_history is ordered by submission time and tracks new tasks"""
        hist = self.client.hub_history()
        recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
        recdict = {}
        for rec in recs:
            recdict[rec['msg_id']] = rec

        # history must be monotonically non-decreasing in 'submitted'
        latest = datetime(1984,1,1)
        for msg_id in hist:
            rec = recdict[msg_id]
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        ar = self.client[-1].apply_async(lambda : 1)
        ar.get()
        time.sleep(0.25)
        self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)

    def _wait_for_idle(self):
        """wait for an engine to become idle, according to the Hub"""
        rc = self.client

        # step 1. wait for all requests to be noticed
        # timeout 5s, polling every 100ms
        msg_ids = set(rc.history)
        hub_hist = rc.hub_history()
        for i in range(50):
            if msg_ids.difference(hub_hist):
                time.sleep(0.1)
                hub_hist = rc.hub_history()
            else:
                break

        self.assertEqual(len(msg_ids.difference(hub_hist)), 0)

        # step 2. wait for all requests to be done
        # timeout 5s, polling every 100ms
        qs = rc.queue_status()
        for i in range(50):
            if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
                time.sleep(0.1)
                qs = rc.queue_status()
            else:
                break

        # ensure Hub up to date:
        self.assertEqual(qs['unassigned'], 0)
        for eid in rc.ids:
            self.assertEqual(qs[eid]['tasks'], 0)


    def test_resubmit(self):
        """resubmitting a task re-runs it, producing a fresh result"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)
        self.assertFalse(r1 == r2)

    def test_resubmit_chain(self):
        """resubmit resubmitted tasks"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(lambda x: x, 'x'*1024)
        ar.get()
        self._wait_for_idle()
        ars = [ar]

        for i in range(10):
            ar = ars[-1]
            ar2 = self.client.resubmit(ar.msg_ids)
            # extend the chain, so each iteration resubmits the *previous*
            # resubmission rather than the original task every time
            ars.append(ar2)

        [ ar.get() for ar in ars ]

    def test_resubmit_header(self):
        """resubmit shouldn't clobber the whole header"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        v.retries = 1
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        ahr.get(1)
        time.sleep(0.5)
        records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
        h1,h2 = [ r['header'] for r in records ]
        # only identity fields may differ between original and resubmission
        for key in set(h1.keys()).union(set(h2.keys())):
            if key in ('msg_id', 'date'):
                self.assertNotEqual(h1[key], h2[key])
            else:
                self.assertEqual(h1[key], h2[key])

    def test_resubmit_aborted(self):
        """an aborted task can be resubmitted successfully"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        # restrict to one engine, so we can put a sleep
        # ahead of the task, so it will get aborted
        eid = self.client.ids[-1]
        v.targets = [eid]
        sleep = v.apply_async(time.sleep, 0.5)
        ar = v.apply_async(f)
        ar.abort()
        self.assertRaises(error.TaskAborted, ar.get)
        # Give the Hub a chance to get up to date:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)

    def test_resubmit_inflight(self):
        """resubmit of inflight task"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep,1)
        # give the message a chance to arrive
        time.sleep(0.2)
        ahr = self.client.resubmit(ar.msg_ids)
        ar.get(2)
        ahr.get(2)

    def test_resubmit_badkey(self):
        """ensure KeyError on resubmit of nonexistant task"""
        self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])

    def test_purge_hub_results(self):
        """purging one result from the Hub shortens hub_history by one"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        # Wait for the Hub to realise the result is done:
        # This prevents a race condition, where we
        # might purge a result the Hub still thinks is pending.
        self._wait_for_idle()
        rc2 = clientmod.Client(profile='iptest')
        hist = self.client.hub_history()
        ahr = rc2.get_result([hist[-1]])
        ahr.wait(10)
        self.client.purge_hub_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEqual(len(newhist)+1,len(hist))
        rc2.spin()
        rc2.close()

    def test_purge_local_results(self):
        """purging local results removes them from results and metadata"""
        # ensure there are some tasks
        res = []
        for i in range(5):
            res.append(self.client[:].apply_async(lambda : 1))
        self._wait_for_idle()
        self.client.wait(10) # wait for the results to come back
        before = len(self.client.results)
        self.assertEqual(len(self.client.metadata),before)
        self.client.purge_local_results(res[-1])
        self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
        self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")

    def test_purge_all_hub_results(self):
        """purge_hub_results('all') empties hub_history"""
        self.client.purge_hub_results('all')
        hist = self.client.hub_history()
        self.assertEqual(len(hist), 0)

    def test_purge_all_local_results(self):
        """purge_local_results('all') empties results and metadata"""
        self.client.purge_local_results('all')
        self.assertEqual(len(self.client.results), 0, msg="Results not empty")
        self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")

    def test_purge_all_results(self):
        """purge_results('all') clears both local and hub storage"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        self.client.wait(10)
        self._wait_for_idle()
        self.client.purge_results('all')
        self.assertEqual(len(self.client.results), 0, msg="Results not empty")
        self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
        hist = self.client.hub_history()
        self.assertEqual(len(hist), 0, msg="hub history not empty")

    def test_purge_everything(self):
        """purge_everything clears results, bookkeeping, and hub history"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        self.client.wait(10)
        self._wait_for_idle()
        self.client.purge_everything()
        # The client results
        self.assertEqual(len(self.client.results), 0, msg="Results not empty")
        self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
        # The client "bookkeeping"
        self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
        self.assertEqual(len(self.client.history), 0, msg="client history not empty")
        # the hub results
        hist = self.client.hub_history()
        self.assertEqual(len(hist), 0, msg="hub history not empty")


    def test_spin_thread(self):
        """a background spin thread keeps wall_time accurate"""
        self.client.spin_thread(0.01)
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.1)
        self.assertTrue(ar.wall_time < 0.1,
            "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
        )

    def test_stop_spin_thread(self):
        """after stop_spin_thread, wall_time only updates on explicit spin"""
        self.client.spin_thread(0.01)
        self.client.stop_spin_thread()
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.15)
        self.assertTrue(ar.wall_time > 0.1,
            "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
        )

    def test_activate(self):
        """activate registers suffixed %px magics bound to a view"""
        ip = get_ipython()
        magics = ip.magics_manager.magics
        self.assertTrue('px' in magics['line'])
        self.assertTrue('px' in magics['cell'])
        v0 = self.client.activate(-1, '0')
        self.assertTrue('px0' in magics['line'])
        self.assertTrue('px0' in magics['cell'])
        self.assertEqual(v0.targets, self.client.ids[-1])
        v0 = self.client.activate('all', 'all')
        self.assertTrue('pxall' in magics['line'])
        self.assertTrue('pxall' in magics['cell'])
        self.assertEqual(v0.targets, 'all')
@@ -1,789 +1,789 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from collections import namedtuple
23 23 from tempfile import mktemp
24 24 from StringIO import StringIO
25 25
26 26 import zmq
27 27 from nose import SkipTest
28 28 from nose.plugins.attrib import attr
29 29
30 30 from IPython.testing import decorators as dec
31 31 from IPython.testing.ipunittest import ParametricTestCase
32 32 from IPython.utils.io import capture_output
33 33
34 34 from IPython import parallel as pmod
35 35 from IPython.parallel import error
36 36 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 37 from IPython.parallel import DirectView
38 38 from IPython.parallel.util import interactive
39 39
40 40 from IPython.parallel.tests import add_engines
41 41
42 42 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 43
def setup():
    """Module-level nose setup: ensure three engines are running."""
    add_engines(3, total=True)
46 46
47 47 point = namedtuple("point", "x y")
48 48
49 49 class TestView(ClusterTestCase, ParametricTestCase):
50 50
51 51 def setUp(self):
52 52 # On Win XP, wait for resource cleanup, else parallel test group fails
53 53 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 54 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 55 time.sleep(2)
56 56 super(TestView, self).setUp()
57 57
58 58 @attr('crash')
59 59 def test_z_crash_mux(self):
60 60 """test graceful handling of engine death (direct)"""
61 61 # self.add_engines(1)
62 62 eid = self.client.ids[-1]
63 63 ar = self.client[eid].apply_async(crash)
64 64 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 65 eid = ar.engine_id
66 66 tic = time.time()
67 67 while eid in self.client.ids and time.time()-tic < 5:
68 68 time.sleep(.01)
69 69 self.client.spin()
70 70 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 71
72 72 def test_push_pull(self):
73 73 """test pushing and pulling"""
74 74 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 75 t = self.client.ids[-1]
76 76 v = self.client[t]
77 77 push = v.push
78 78 pull = v.pull
79 79 v.block=True
80 80 nengines = len(self.client)
81 81 push({'data':data})
82 82 d = pull('data')
83 83 self.assertEqual(d, data)
84 84 self.client[:].push({'data':data})
85 85 d = self.client[:].pull('data', block=True)
86 86 self.assertEqual(d, nengines*[data])
87 87 ar = push({'data':data}, block=False)
88 88 self.assertTrue(isinstance(ar, AsyncResult))
89 89 r = ar.get()
90 90 ar = self.client[:].pull('data', block=False)
91 91 self.assertTrue(isinstance(ar, AsyncResult))
92 92 r = ar.get()
93 93 self.assertEqual(r, nengines*[data])
94 94 self.client[:].push(dict(a=10,b=20))
95 95 r = self.client[:].pull(('a','b'), block=True)
96 96 self.assertEqual(r, nengines*[[10,20]])
97 97
98 98 def test_push_pull_function(self):
99 99 "test pushing and pulling functions"
100 100 def testf(x):
101 101 return 2.0*x
102 102
103 103 t = self.client.ids[-1]
104 104 v = self.client[t]
105 105 v.block=True
106 106 push = v.push
107 107 pull = v.pull
108 108 execute = v.execute
109 109 push({'testf':testf})
110 110 r = pull('testf')
111 111 self.assertEqual(r(1.0), testf(1.0))
112 112 execute('r = testf(10)')
113 113 r = pull('r')
114 114 self.assertEqual(r, testf(10))
115 115 ar = self.client[:].push({'testf':testf}, block=False)
116 116 ar.get()
117 117 ar = self.client[:].pull('testf', block=False)
118 118 rlist = ar.get()
119 119 for r in rlist:
120 120 self.assertEqual(r(1.0), testf(1.0))
121 121 execute("def g(x): return x*x")
122 122 r = pull(('testf','g'))
123 123 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 124
125 125 def test_push_function_globals(self):
126 126 """test that pushed functions have access to globals"""
127 127 @interactive
128 128 def geta():
129 129 return a
130 130 # self.add_engines(1)
131 131 v = self.client[-1]
132 132 v.block=True
133 133 v['f'] = geta
134 134 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 135 v.execute('a=5')
136 136 v.execute('b=f()')
137 137 self.assertEqual(v['b'], 5)
138 138
139 139 def test_push_function_defaults(self):
140 140 """test that pushed functions preserve default args"""
141 141 def echo(a=10):
142 142 return a
143 143 v = self.client[-1]
144 144 v.block=True
145 145 v['f'] = echo
146 146 v.execute('b=f()')
147 147 self.assertEqual(v['b'], 10)
148 148
149 149 def test_get_result(self):
150 150 """test getting results from the Hub."""
151 151 c = pmod.Client(profile='iptest')
152 152 # self.add_engines(1)
153 153 t = c.ids[-1]
154 154 v = c[t]
155 155 v2 = self.client[t]
156 156 ar = v.apply_async(wait, 1)
157 157 # give the monitor time to notice the message
158 158 time.sleep(.25)
159 ahr = v2.get_result(ar.msg_ids)
159 ahr = v2.get_result(ar.msg_ids[0])
160 160 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 161 self.assertEqual(ahr.get(), ar.get())
162 ar2 = v2.get_result(ar.msg_ids)
162 ar2 = v2.get_result(ar.msg_ids[0])
163 163 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 164 c.spin()
165 165 c.close()
166 166
167 167 def test_run_newline(self):
168 168 """test that run appends newline to files"""
169 169 tmpfile = mktemp()
170 170 with open(tmpfile, 'w') as f:
171 171 f.write("""def g():
172 172 return 5
173 173 """)
174 174 v = self.client[-1]
175 175 v.run(tmpfile, block=True)
176 176 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 177
178 178 def test_apply_tracked(self):
179 179 """test tracking for apply"""
180 180 # self.add_engines(1)
181 181 t = self.client.ids[-1]
182 182 v = self.client[t]
183 183 v.block=False
184 184 def echo(n=1024*1024, **kwargs):
185 185 with v.temp_flags(**kwargs):
186 186 return v.apply(lambda x: x, 'x'*n)
187 187 ar = echo(1, track=False)
188 188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 189 self.assertTrue(ar.sent)
190 190 ar = echo(track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 self.assertEqual(ar.sent, ar._tracker.done)
193 193 ar._tracker.wait()
194 194 self.assertTrue(ar.sent)
195 195
196 196 def test_push_tracked(self):
197 197 t = self.client.ids[-1]
198 198 ns = dict(x='x'*1024*1024)
199 199 v = self.client[t]
200 200 ar = v.push(ns, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = v.push(ns, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 ar._tracker.wait()
207 207 self.assertEqual(ar.sent, ar._tracker.done)
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_scatter_tracked(self):
212 212 t = self.client.ids
213 213 x='x'*1024*1024
214 214 ar = self.client[t].scatter('x', x, block=False, track=False)
215 215 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 216 self.assertTrue(ar.sent)
217 217
218 218 ar = self.client[t].scatter('x', x, block=False, track=True)
219 219 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 220 self.assertEqual(ar.sent, ar._tracker.done)
221 221 ar._tracker.wait()
222 222 self.assertTrue(ar.sent)
223 223 ar.get()
224 224
225 225 def test_remote_reference(self):
226 226 v = self.client[-1]
227 227 v['a'] = 123
228 228 ra = pmod.Reference('a')
229 229 b = v.apply_sync(lambda x: x, ra)
230 230 self.assertEqual(b, 123)
231 231
232 232
233 233 def test_scatter_gather(self):
234 234 view = self.client[:]
235 235 seq1 = range(16)
236 236 view.scatter('a', seq1)
237 237 seq2 = view.gather('a', block=True)
238 238 self.assertEqual(seq2, seq1)
239 239 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 240
241 241 @skip_without('numpy')
242 242 def test_scatter_gather_numpy(self):
243 243 import numpy
244 244 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 245 view = self.client[:]
246 246 a = numpy.arange(64)
247 247 view.scatter('a', a, block=True)
248 248 b = view.gather('a', block=True)
249 249 assert_array_equal(b, a)
250 250
251 251 def test_scatter_gather_lazy(self):
252 252 """scatter/gather with targets='all'"""
253 253 view = self.client.direct_view(targets='all')
254 254 x = range(64)
255 255 view.scatter('x', x)
256 256 gathered = view.gather('x', block=True)
257 257 self.assertEqual(gathered, x)
258 258
259 259
260 260 @dec.known_failure_py3
261 261 @skip_without('numpy')
262 262 def test_push_numpy_nocopy(self):
263 263 import numpy
264 264 view = self.client[:]
265 265 a = numpy.arange(64)
266 266 view['A'] = a
267 267 @interactive
268 268 def check_writeable(x):
269 269 return x.flags.writeable
270 270
271 271 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 272 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 273
274 274 view.push(dict(B=a))
275 275 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 276 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 277
278 278 @skip_without('numpy')
279 279 def test_apply_numpy(self):
280 280 """view.apply(f, ndarray)"""
281 281 import numpy
282 282 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 283
284 284 A = numpy.random.random((100,100))
285 285 view = self.client[-1]
286 286 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 287 B = A.astype(dt)
288 288 C = view.apply_sync(lambda x:x, B)
289 289 assert_array_equal(B,C)
290 290
291 291 @skip_without('numpy')
292 292 def test_push_pull_recarray(self):
293 293 """push/pull recarrays"""
294 294 import numpy
295 295 from numpy.testing.utils import assert_array_equal
296 296
297 297 view = self.client[-1]
298 298
299 299 R = numpy.array([
300 300 (1, 'hi', 0.),
301 301 (2**30, 'there', 2.5),
302 302 (-99999, 'world', -12345.6789),
303 303 ], [('n', int), ('s', '|S10'), ('f', float)])
304 304
305 305 view['RR'] = R
306 306 R2 = view['RR']
307 307
308 308 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 309 self.assertEqual(r_dtype, R.dtype)
310 310 self.assertEqual(r_shape, R.shape)
311 311 self.assertEqual(R2.dtype, R.dtype)
312 312 self.assertEqual(R2.shape, R.shape)
313 313 assert_array_equal(R2, R)
314 314
315 315 @skip_without('pandas')
316 316 def test_push_pull_timeseries(self):
317 317 """push/pull pandas.TimeSeries"""
318 318 import pandas
319 319
320 320 ts = pandas.TimeSeries(range(10))
321 321
322 322 view = self.client[-1]
323 323
324 324 view.push(dict(ts=ts), block=True)
325 325 rts = view['ts']
326 326
327 327 self.assertEqual(type(rts), type(ts))
328 328 self.assertTrue((ts == rts).all())
329 329
330 330 def test_map(self):
331 331 view = self.client[:]
332 332 def f(x):
333 333 return x**2
334 334 data = range(16)
335 335 r = view.map_sync(f, data)
336 336 self.assertEqual(r, map(f, data))
337 337
338 338 def test_map_iterable(self):
339 339 """test map on iterables (direct)"""
340 340 view = self.client[:]
341 341 # 101 is prime, so it won't be evenly distributed
342 342 arr = range(101)
343 343 # ensure it will be an iterator, even in Python 3
344 344 it = iter(arr)
345 345 r = view.map_sync(lambda x:x, arr)
346 346 self.assertEqual(r, list(arr))
347 347
348 348 def test_scatter_gather_nonblocking(self):
349 349 data = range(16)
350 350 view = self.client[:]
351 351 view.scatter('a', data, block=False)
352 352 ar = view.gather('a', block=False)
353 353 self.assertEqual(ar.get(), data)
354 354
355 355 @skip_without('numpy')
356 356 def test_scatter_gather_numpy_nonblocking(self):
357 357 import numpy
358 358 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
359 359 a = numpy.arange(64)
360 360 view = self.client[:]
361 361 ar = view.scatter('a', a, block=False)
362 362 self.assertTrue(isinstance(ar, AsyncResult))
363 363 amr = view.gather('a', block=False)
364 364 self.assertTrue(isinstance(amr, AsyncMapResult))
365 365 assert_array_equal(amr.get(), a)
366 366
367 367 def test_execute(self):
368 368 view = self.client[:]
369 369 # self.client.debug=True
370 370 execute = view.execute
371 371 ar = execute('c=30', block=False)
372 372 self.assertTrue(isinstance(ar, AsyncResult))
373 373 ar = execute('d=[0,1,2]', block=False)
374 374 self.client.wait(ar, 1)
375 375 self.assertEqual(len(ar.get()), len(self.client))
376 376 for c in view['c']:
377 377 self.assertEqual(c, 30)
378 378
379 379 def test_abort(self):
380 380 view = self.client[-1]
381 381 ar = view.execute('import time; time.sleep(1)', block=False)
382 382 ar2 = view.apply_async(lambda : 2)
383 383 ar3 = view.apply_async(lambda : 3)
384 384 view.abort(ar2)
385 385 view.abort(ar3.msg_ids)
386 386 self.assertRaises(error.TaskAborted, ar2.get)
387 387 self.assertRaises(error.TaskAborted, ar3.get)
388 388
389 389 def test_abort_all(self):
390 390 """view.abort() aborts all outstanding tasks"""
391 391 view = self.client[-1]
392 392 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
393 393 view.abort()
394 394 view.wait(timeout=5)
395 395 for ar in ars[5:]:
396 396 self.assertRaises(error.TaskAborted, ar.get)
397 397
398 398 def test_temp_flags(self):
399 399 view = self.client[-1]
400 400 view.block=True
401 401 with view.temp_flags(block=False):
402 402 self.assertFalse(view.block)
403 403 self.assertTrue(view.block)
404 404
405 405 @dec.known_failure_py3
406 406 def test_importer(self):
407 407 view = self.client[-1]
408 408 view.clear(block=True)
409 409 with view.importer:
410 410 import re
411 411
412 412 @interactive
413 413 def findall(pat, s):
414 414 # this globals() step isn't necessary in real code
415 415 # only to prevent a closure in the test
416 416 re = globals()['re']
417 417 return re.findall(pat, s)
418 418
419 419 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
420 420
421 421 def test_unicode_execute(self):
422 422 """test executing unicode strings"""
423 423 v = self.client[-1]
424 424 v.block=True
425 425 if sys.version_info[0] >= 3:
426 426 code="a='é'"
427 427 else:
428 428 code=u"a=u'é'"
429 429 v.execute(code)
430 430 self.assertEqual(v['a'], u'é')
431 431
432 432 def test_unicode_apply_result(self):
433 433 """test unicode apply results"""
434 434 v = self.client[-1]
435 435 r = v.apply_sync(lambda : u'é')
436 436 self.assertEqual(r, u'é')
437 437
438 438 def test_unicode_apply_arg(self):
439 439 """test passing unicode arguments to apply"""
440 440 v = self.client[-1]
441 441
442 442 @interactive
443 443 def check_unicode(a, check):
444 444 assert isinstance(a, unicode), "%r is not unicode"%a
445 445 assert isinstance(check, bytes), "%r is not bytes"%check
446 446 assert a.encode('utf8') == check, "%s != %s"%(a,check)
447 447
448 448 for s in [ u'é', u'ßø®∫',u'asdf' ]:
449 449 try:
450 450 v.apply_sync(check_unicode, s, s.encode('utf8'))
451 451 except error.RemoteError as e:
452 452 if e.ename == 'AssertionError':
453 453 self.fail(e.evalue)
454 454 else:
455 455 raise e
456 456
457 457 def test_map_reference(self):
458 458 """view.map(<Reference>, *seqs) should work"""
459 459 v = self.client[:]
460 460 v.scatter('n', self.client.ids, flatten=True)
461 461 v.execute("f = lambda x,y: x*y")
462 462 rf = pmod.Reference('f')
463 463 nlist = list(range(10))
464 464 mlist = nlist[::-1]
465 465 expected = [ m*n for m,n in zip(mlist, nlist) ]
466 466 result = v.map_sync(rf, mlist, nlist)
467 467 self.assertEqual(result, expected)
468 468
469 469 def test_apply_reference(self):
470 470 """view.apply(<Reference>, *args) should work"""
471 471 v = self.client[:]
472 472 v.scatter('n', self.client.ids, flatten=True)
473 473 v.execute("f = lambda x: n*x")
474 474 rf = pmod.Reference('f')
475 475 result = v.apply_sync(rf, 5)
476 476 expected = [ 5*id for id in self.client.ids ]
477 477 self.assertEqual(result, expected)
478 478
479 479 def test_eval_reference(self):
480 480 v = self.client[self.client.ids[0]]
481 481 v['g'] = range(5)
482 482 rg = pmod.Reference('g[0]')
483 483 echo = lambda x:x
484 484 self.assertEqual(v.apply_sync(echo, rg), 0)
485 485
486 486 def test_reference_nameerror(self):
487 487 v = self.client[self.client.ids[0]]
488 488 r = pmod.Reference('elvis_has_left')
489 489 echo = lambda x:x
490 490 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
491 491
492 492 def test_single_engine_map(self):
493 493 e0 = self.client[self.client.ids[0]]
494 494 r = range(5)
495 495 check = [ -1*i for i in r ]
496 496 result = e0.map_sync(lambda x: -1*x, r)
497 497 self.assertEqual(result, check)
498 498
    def test_len(self):
        """len(view) makes sense"""
        # NOTE(review): the yields make this a nose-style generator test;
        # assertEqual returns None, so each yield emits None rather than a
        # (callable, args) pair -- presumably the runner tolerates this.
        # Confirm runner semantics before restructuring into plain asserts.
        e0 = self.client[self.client.ids[0]]
        yield self.assertEqual(len(e0), 1)
        v = self.client[:]
        yield self.assertEqual(len(v), len(self.client.ids))
        v = self.client.direct_view('all')
        yield self.assertEqual(len(v), len(self.client.ids))
        v = self.client[:2]
        yield self.assertEqual(len(v), 2)
        v = self.client[:1]
        yield self.assertEqual(len(v), 1)
        v = self.client.load_balanced_view()
        yield self.assertEqual(len(v), len(self.client.ids))
        # parametric tests seem to require manual closing?
        self.client.close()
515 515
516 516
517 517 # begin execute tests
518 518
519 519 def test_execute_reply(self):
520 520 e0 = self.client[self.client.ids[0]]
521 521 e0.block = True
522 522 ar = e0.execute("5", silent=False)
523 523 er = ar.get()
524 524 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
525 525 self.assertEqual(er.pyout['data']['text/plain'], '5')
526 526
527 527 def test_execute_reply_stdout(self):
528 528 e0 = self.client[self.client.ids[0]]
529 529 e0.block = True
530 530 ar = e0.execute("print (5)", silent=False)
531 531 er = ar.get()
532 532 self.assertEqual(er.stdout.strip(), '5')
533 533
534 534 def test_execute_pyout(self):
535 535 """execute triggers pyout with silent=False"""
536 536 view = self.client[:]
537 537 ar = view.execute("5", silent=False, block=True)
538 538
539 539 expected = [{'text/plain' : '5'}] * len(view)
540 540 mimes = [ out['data'] for out in ar.pyout ]
541 541 self.assertEqual(mimes, expected)
542 542
543 543 def test_execute_silent(self):
544 544 """execute does not trigger pyout with silent=True"""
545 545 view = self.client[:]
546 546 ar = view.execute("5", block=True)
547 547 expected = [None] * len(view)
548 548 self.assertEqual(ar.pyout, expected)
549 549
550 550 def test_execute_magic(self):
551 551 """execute accepts IPython commands"""
552 552 view = self.client[:]
553 553 view.execute("a = 5")
554 554 ar = view.execute("%whos", block=True)
555 555 # this will raise, if that failed
556 556 ar.get(5)
557 557 for stdout in ar.stdout:
558 558 lines = stdout.splitlines()
559 559 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
560 560 found = False
561 561 for line in lines[2:]:
562 562 split = line.split()
563 563 if split == ['a', 'int', '5']:
564 564 found = True
565 565 break
566 566 self.assertTrue(found, "whos output wrong: %s" % stdout)
567 567
568 568 def test_execute_displaypub(self):
569 569 """execute tracks display_pub output"""
570 570 view = self.client[:]
571 571 view.execute("from IPython.core.display import *")
572 572 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
573 573
574 574 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
575 575 for outputs in ar.outputs:
576 576 mimes = [ out['data'] for out in outputs ]
577 577 self.assertEqual(mimes, expected)
578 578
579 579 def test_apply_displaypub(self):
580 580 """apply tracks display_pub output"""
581 581 view = self.client[:]
582 582 view.execute("from IPython.core.display import *")
583 583
584 584 @interactive
585 585 def publish():
586 586 [ display(i) for i in range(5) ]
587 587
588 588 ar = view.apply_async(publish)
589 589 ar.get(5)
590 590 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
591 591 for outputs in ar.outputs:
592 592 mimes = [ out['data'] for out in outputs ]
593 593 self.assertEqual(mimes, expected)
594 594
595 595 def test_execute_raises(self):
596 596 """exceptions in execute requests raise appropriately"""
597 597 view = self.client[-1]
598 598 ar = view.execute("1/0")
599 599 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
600 600
601 601 def test_remoteerror_render_exception(self):
602 602 """RemoteErrors get nice tracebacks"""
603 603 view = self.client[-1]
604 604 ar = view.execute("1/0")
605 605 ip = get_ipython()
606 606 ip.user_ns['ar'] = ar
607 607 with capture_output() as io:
608 608 ip.run_cell("ar.get(2)")
609 609
610 610 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
611 611
612 612 def test_compositeerror_render_exception(self):
613 613 """CompositeErrors get nice tracebacks"""
614 614 view = self.client[:]
615 615 ar = view.execute("1/0")
616 616 ip = get_ipython()
617 617 ip.user_ns['ar'] = ar
618 618
619 619 with capture_output() as io:
620 620 ip.run_cell("ar.get(2)")
621 621
622 622 count = min(error.CompositeError.tb_limit, len(view))
623 623
624 624 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
625 625 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
626 626 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
627 627
628 628 def test_compositeerror_truncate(self):
629 629 """Truncate CompositeErrors with many exceptions"""
630 630 view = self.client[:]
631 631 msg_ids = []
632 632 for i in range(10):
633 633 ar = view.execute("1/0")
634 634 msg_ids.extend(ar.msg_ids)
635 635
636 636 ar = self.client.get_result(msg_ids)
637 637 try:
638 638 ar.get()
639 639 except error.CompositeError as _e:
640 640 e = _e
641 641 else:
642 642 self.fail("Should have raised CompositeError")
643 643
644 644 lines = e.render_traceback()
645 645 with capture_output() as io:
646 646 e.print_traceback()
647 647
648 648 self.assertTrue("more exceptions" in lines[-1])
649 649 count = e.tb_limit
650 650
651 651 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
652 652 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 653 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654 654
655 655 @dec.skipif_not_matplotlib
656 656 def test_magic_pylab(self):
657 657 """%pylab works on engines"""
658 658 view = self.client[-1]
659 659 ar = view.execute("%pylab inline")
660 660 # at least check if this raised:
661 661 reply = ar.get(5)
662 662 # include imports, in case user config
663 663 ar = view.execute("plot(rand(100))", silent=False)
664 664 reply = ar.get(5)
665 665 self.assertEqual(len(reply.outputs), 1)
666 666 output = reply.outputs[0]
667 667 self.assertTrue("data" in output)
668 668 data = output['data']
669 669 self.assertTrue("image/png" in data)
670 670
671 671 def test_func_default_func(self):
672 672 """interactively defined function as apply func default"""
673 673 def foo():
674 674 return 'foo'
675 675
676 676 def bar(f=foo):
677 677 return f()
678 678
679 679 view = self.client[-1]
680 680 ar = view.apply_async(bar)
681 681 r = ar.get(10)
682 682 self.assertEqual(r, 'foo')
683 683 def test_data_pub_single(self):
684 684 view = self.client[-1]
685 685 ar = view.execute('\n'.join([
686 686 'from IPython.kernel.zmq.datapub import publish_data',
687 687 'for i in range(5):',
688 688 ' publish_data(dict(i=i))'
689 689 ]), block=False)
690 690 self.assertTrue(isinstance(ar.data, dict))
691 691 ar.get(5)
692 692 self.assertEqual(ar.data, dict(i=4))
693 693
694 694 def test_data_pub(self):
695 695 view = self.client[:]
696 696 ar = view.execute('\n'.join([
697 697 'from IPython.kernel.zmq.datapub import publish_data',
698 698 'for i in range(5):',
699 699 ' publish_data(dict(i=i))'
700 700 ]), block=False)
701 701 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
702 702 ar.get(5)
703 703 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
704 704
705 705 def test_can_list_arg(self):
706 706 """args in lists are canned"""
707 707 view = self.client[-1]
708 708 view['a'] = 128
709 709 rA = pmod.Reference('a')
710 710 ar = view.apply_async(lambda x: x, [rA])
711 711 r = ar.get(5)
712 712 self.assertEqual(r, [128])
713 713
714 714 def test_can_dict_arg(self):
715 715 """args in dicts are canned"""
716 716 view = self.client[-1]
717 717 view['a'] = 128
718 718 rA = pmod.Reference('a')
719 719 ar = view.apply_async(lambda x: x, dict(foo=rA))
720 720 r = ar.get(5)
721 721 self.assertEqual(r, dict(foo=128))
722 722
723 723 def test_can_list_kwarg(self):
724 724 """kwargs in lists are canned"""
725 725 view = self.client[-1]
726 726 view['a'] = 128
727 727 rA = pmod.Reference('a')
728 728 ar = view.apply_async(lambda x=5: x, x=[rA])
729 729 r = ar.get(5)
730 730 self.assertEqual(r, [128])
731 731
732 732 def test_can_dict_kwarg(self):
733 733 """kwargs in dicts are canned"""
734 734 view = self.client[-1]
735 735 view['a'] = 128
736 736 rA = pmod.Reference('a')
737 737 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
738 738 r = ar.get(5)
739 739 self.assertEqual(r, dict(foo=128))
740 740
741 741 def test_map_ref(self):
742 742 """view.map works with references"""
743 743 view = self.client[:]
744 744 ranks = sorted(self.client.ids)
745 745 view.scatter('rank', ranks, flatten=True)
746 746 rrank = pmod.Reference('rank')
747 747
748 748 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
749 749 drank = amr.get(5)
750 750 self.assertEqual(drank, [ r*2 for r in ranks ])
751 751
752 752 def test_nested_getitem_setitem(self):
753 753 """get and set with view['a.b']"""
754 754 view = self.client[-1]
755 755 view.execute('\n'.join([
756 756 'class A(object): pass',
757 757 'a = A()',
758 758 'a.b = 128',
759 759 ]), block=True)
760 760 ra = pmod.Reference('a')
761 761
762 762 r = view.apply_sync(lambda x: x.b, ra)
763 763 self.assertEqual(r, 128)
764 764 self.assertEqual(view['a.b'], 128)
765 765
766 766 view['a.b'] = 0
767 767
768 768 r = view.apply_sync(lambda x: x.b, ra)
769 769 self.assertEqual(r, 0)
770 770 self.assertEqual(view['a.b'], 0)
771 771
772 772 def test_return_namedtuple(self):
773 773 def namedtuplify(x, y):
774 774 from IPython.parallel.tests.test_view import point
775 775 return point(x, y)
776 776
777 777 view = self.client[-1]
778 778 p = view.apply_sync(namedtuplify, 1, 2)
779 779 self.assertEqual(p.x, 1)
780 780 self.assertEqual(p.y, 2)
781 781
782 782 def test_apply_namedtuple(self):
783 783 def echoxy(p):
784 784 return p.y, p.x
785 785
786 786 view = self.client[-1]
787 787 tup = view.apply_sync(echoxy, point(1, 2))
788 788 self.assertEqual(tup, (2,1))
789 789
General Comments 0
You need to be logged in to leave comments. Login now