##// END OF EJS Templates
subclass RichOutput in ExecuteReply...
MinRK -
Show More
@@ -1,1839 +1,1845 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 from IPython.utils.capture import RichOutput
37 38 from IPython.utils.coloransi import TermColors
38 39 from IPython.utils.jsonutil import rekey
39 40 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
40 41 from IPython.utils.path import get_ipython_dir
41 42 from IPython.utils.py3compat import cast_bytes
42 43 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 44 Dict, List, Bool, Set, Any)
44 45 from IPython.external.decorator import decorator
45 46 from IPython.external.ssh import tunnel
46 47
47 48 from IPython.parallel import Reference
48 49 from IPython.parallel import error
49 50 from IPython.parallel import util
50 51
51 52 from IPython.kernel.zmq.session import Session, Message
52 53 from IPython.kernel.zmq import serialize
53 54
54 55 from .asyncresult import AsyncResult, AsyncHubResult
55 56 from .view import DirectView, LoadBalancedView
56 57
57 58 if sys.version_info[0] >= 3:
58 59 # xrange is used in a couple 'isinstance' tests in py2
59 60 # should be just 'range' in 3k
60 61 xrange = range
61 62
62 63 #--------------------------------------------------------------------------
63 64 # Decorators for Client methods
64 65 #--------------------------------------------------------------------------
65 66
66 67 @decorator
67 68 def spin_first(f, self, *args, **kwargs):
68 69 """Call spin() to sync state prior to calling the method."""
69 70 self.spin()
70 71 return f(self, *args, **kwargs)
71 72
72 73
73 74 #--------------------------------------------------------------------------
74 75 # Classes
75 76 #--------------------------------------------------------------------------
76 77
77 78
78 class ExecuteReply(object):
79 class ExecuteReply(RichOutput):
79 80 """wrapper for finished Execute results"""
80 81 def __init__(self, msg_id, content, metadata):
81 82 self.msg_id = msg_id
82 83 self._content = content
83 84 self.execution_count = content['execution_count']
84 85 self.metadata = metadata
85 86
87 # RichOutput overrides
88
89 @property
90 def source(self):
91 pyout = self.metadata['pyout']
92 if pyout:
93 return pyout.get('source', '')
94
95 @property
96 def data(self):
97 pyout = self.metadata['pyout']
98 if pyout:
99 return pyout.get('data', {})
100
101 @property
102 def _metadata(self):
103 pyout = self.metadata['pyout']
104 if pyout:
105 return pyout.get('metadata', {})
106
107 def display(self):
108 from IPython.display import publish_display_data
109 publish_display_data(self.source, self.data, self.metadata)
110
111 def _repr_mime_(self, mime):
112 if mime not in self.data:
113 return
114 data = self.data[mime]
115 if mime in self._metadata:
116 return data, self._metadata[mime]
117 else:
118 return data
119
86 120 def __getitem__(self, key):
87 121 return self.metadata[key]
88 122
89 123 def __getattr__(self, key):
90 124 if key not in self.metadata:
91 125 raise AttributeError(key)
92 126 return self.metadata[key]
93 127
94 128 def __repr__(self):
95 129 pyout = self.metadata['pyout'] or {'data':{}}
96 130 text_out = pyout['data'].get('text/plain', '')
97 131 if len(text_out) > 32:
98 132 text_out = text_out[:29] + '...'
99 133
100 134 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
101 135
102 136 def _repr_pretty_(self, p, cycle):
103 137 pyout = self.metadata['pyout'] or {'data':{}}
104 138 text_out = pyout['data'].get('text/plain', '')
105 139
106 140 if not text_out:
107 141 return
108 142
109 143 try:
110 144 ip = get_ipython()
111 145 except NameError:
112 146 colors = "NoColor"
113 147 else:
114 148 colors = ip.colors
115 149
116 150 if colors == "NoColor":
117 151 out = normal = ""
118 152 else:
119 153 out = TermColors.Red
120 154 normal = TermColors.Normal
121 155
122 156 if '\n' in text_out and not text_out.startswith('\n'):
123 157 # add newline for multiline reprs
124 158 text_out = '\n' + text_out
125 159
126 160 p.text(
127 161 out + u'Out[%i:%i]: ' % (
128 162 self.metadata['engine_id'], self.execution_count
129 163 ) + normal + text_out
130 164 )
131
132 def _repr_html_(self):
133 pyout = self.metadata['pyout'] or {'data':{}}
134 return pyout['data'].get("text/html")
135
136 def _repr_latex_(self):
137 pyout = self.metadata['pyout'] or {'data':{}}
138 return pyout['data'].get("text/latex")
139
140 def _repr_json_(self):
141 pyout = self.metadata['pyout'] or {'data':{}}
142 return pyout['data'].get("application/json")
143
144 def _repr_javascript_(self):
145 pyout = self.metadata['pyout'] or {'data':{}}
146 return pyout['data'].get("application/javascript")
147
148 def _repr_png_(self):
149 pyout = self.metadata['pyout'] or {'data':{}}
150 return pyout['data'].get("image/png")
151
152 def _repr_jpeg_(self):
153 pyout = self.metadata['pyout'] or {'data':{}}
154 return pyout['data'].get("image/jpeg")
155
156 def _repr_svg_(self):
157 pyout = self.metadata['pyout'] or {'data':{}}
158 return pyout['data'].get("image/svg+xml")
159 165
160 166
161 167 class Metadata(dict):
162 168 """Subclass of dict for initializing metadata values.
163 169
164 170 Attribute access works on keys.
165 171
166 172 These objects have a strict set of keys - errors will raise if you try
167 173 to add new keys.
168 174 """
169 175 def __init__(self, *args, **kwargs):
170 176 dict.__init__(self)
171 177 md = {'msg_id' : None,
172 178 'submitted' : None,
173 179 'started' : None,
174 180 'completed' : None,
175 181 'received' : None,
176 182 'engine_uuid' : None,
177 183 'engine_id' : None,
178 184 'follow' : None,
179 185 'after' : None,
180 186 'status' : None,
181 187
182 188 'pyin' : None,
183 189 'pyout' : None,
184 190 'pyerr' : None,
185 191 'stdout' : '',
186 192 'stderr' : '',
187 193 'outputs' : [],
188 194 'data': {},
189 195 'outputs_ready' : False,
190 196 }
191 197 self.update(md)
192 198 self.update(dict(*args, **kwargs))
193 199
194 200 def __getattr__(self, key):
195 201 """getattr aliased to getitem"""
196 202 if key in self.iterkeys():
197 203 return self[key]
198 204 else:
199 205 raise AttributeError(key)
200 206
201 207 def __setattr__(self, key, value):
202 208 """setattr aliased to setitem, with strict"""
203 209 if key in self.iterkeys():
204 210 self[key] = value
205 211 else:
206 212 raise AttributeError(key)
207 213
208 214 def __setitem__(self, key, value):
209 215 """strict static key enforcement"""
210 216 if key in self.iterkeys():
211 217 dict.__setitem__(self, key, value)
212 218 else:
213 219 raise KeyError(key)
214 220
215 221
216 222 class Client(HasTraits):
217 223 """A semi-synchronous client to the IPython ZMQ cluster
218 224
219 225 Parameters
220 226 ----------
221 227
222 228 url_file : str/unicode; path to ipcontroller-client.json
223 229 This JSON file should contain all the information needed to connect to a cluster,
224 230 and is likely the only argument needed.
225 231 Connection information for the Hub's registration. If a json connector
226 232 file is given, then likely no further configuration is necessary.
227 233 [Default: use profile]
228 234 profile : bytes
229 235 The name of the Cluster profile to be used to find connector information.
230 236 If run from an IPython application, the default profile will be the same
231 237 as the running application, otherwise it will be 'default'.
232 238 cluster_id : str
233 239 String id to added to runtime files, to prevent name collisions when using
234 240 multiple clusters with a single profile simultaneously.
235 241 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
236 242 Since this is text inserted into filenames, typical recommendations apply:
237 243 Simple character strings are ideal, and spaces are not recommended (but
238 244 should generally work)
239 245 context : zmq.Context
240 246 Pass an existing zmq.Context instance, otherwise the client will create its own.
241 247 debug : bool
242 248 flag for lots of message printing for debug purposes
243 249 timeout : int/float
244 250 time (in seconds) to wait for connection replies from the Hub
245 251 [Default: 10]
246 252
247 253 #-------------- session related args ----------------
248 254
249 255 config : Config object
250 256 If specified, this will be relayed to the Session for configuration
251 257 username : str
252 258 set username for the session object
253 259
254 260 #-------------- ssh related args ----------------
255 261 # These are args for configuring the ssh tunnel to be used
256 262 # credentials are used to forward connections over ssh to the Controller
257 263 # Note that the ip given in `addr` needs to be relative to sshserver
258 264 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
259 265 # and set sshserver as the same machine the Controller is on. However,
260 266 # the only requirement is that sshserver is able to see the Controller
261 267 # (i.e. is within the same trusted network).
262 268
263 269 sshserver : str
264 270 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
265 271 If keyfile or password is specified, and this is not, it will default to
266 272 the ip given in addr.
267 273 sshkey : str; path to ssh private key file
268 274 This specifies a key to be used in ssh login, default None.
269 275 Regular default ssh keys will be used without specifying this argument.
270 276 password : str
271 277 Your ssh password to sshserver. Note that if this is left None,
272 278 you will be prompted for it if passwordless key based login is unavailable.
273 279 paramiko : bool
274 280 flag for whether to use paramiko instead of shell ssh for tunneling.
275 281 [default: True on win32, False else]
276 282
277 283
278 284 Attributes
279 285 ----------
280 286
281 287 ids : list of int engine IDs
282 288 requesting the ids attribute always synchronizes
283 289 the registration state. To request ids without synchronization,
284 290 use semi-private _ids attributes.
285 291
286 292 history : list of msg_ids
287 293 a list of msg_ids, keeping track of all the execution
288 294 messages you have submitted in order.
289 295
290 296 outstanding : set of msg_ids
291 297 a set of msg_ids that have been submitted, but whose
292 298 results have not yet been received.
293 299
294 300 results : dict
295 301 a dict of all our results, keyed by msg_id
296 302
297 303 block : bool
298 304 determines default behavior when block not specified
299 305 in execution methods
300 306
301 307 Methods
302 308 -------
303 309
304 310 spin
305 311 flushes incoming results and registration state changes
306 312 control methods spin, and requesting `ids` also ensures up to date
307 313
308 314 wait
309 315 wait on one or more msg_ids
310 316
311 317 execution methods
312 318 apply
313 319 legacy: execute, run
314 320
315 321 data movement
316 322 push, pull, scatter, gather
317 323
318 324 query methods
319 325 queue_status, get_result, purge, result_status
320 326
321 327 control methods
322 328 abort, shutdown
323 329
324 330 """
325 331
326 332
327 333 block = Bool(False)
328 334 outstanding = Set()
329 335 results = Instance('collections.defaultdict', (dict,))
330 336 metadata = Instance('collections.defaultdict', (Metadata,))
331 337 history = List()
332 338 debug = Bool(False)
333 339 _spin_thread = Any()
334 340 _stop_spinning = Any()
335 341
336 342 profile=Unicode()
337 343 def _profile_default(self):
338 344 if BaseIPythonApplication.initialized():
339 345 # an IPython app *might* be running, try to get its profile
340 346 try:
341 347 return BaseIPythonApplication.instance().profile
342 348 except (AttributeError, MultipleInstanceError):
343 349 # could be a *different* subclass of config.Application,
344 350 # which would raise one of these two errors.
345 351 return u'default'
346 352 else:
347 353 return u'default'
348 354
349 355
350 356 _outstanding_dict = Instance('collections.defaultdict', (set,))
351 357 _ids = List()
352 358 _connected=Bool(False)
353 359 _ssh=Bool(False)
354 360 _context = Instance('zmq.Context')
355 361 _config = Dict()
356 362 _engines=Instance(util.ReverseDict, (), {})
357 363 # _hub_socket=Instance('zmq.Socket')
358 364 _query_socket=Instance('zmq.Socket')
359 365 _control_socket=Instance('zmq.Socket')
360 366 _iopub_socket=Instance('zmq.Socket')
361 367 _notification_socket=Instance('zmq.Socket')
362 368 _mux_socket=Instance('zmq.Socket')
363 369 _task_socket=Instance('zmq.Socket')
364 370 _task_scheme=Unicode()
365 371 _closed = False
366 372 _ignored_control_replies=Integer(0)
367 373 _ignored_hub_replies=Integer(0)
368 374
369 375 def __new__(self, *args, **kw):
370 376 # don't raise on positional args
371 377 return HasTraits.__new__(self, **kw)
372 378
373 379 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
374 380 context=None, debug=False,
375 381 sshserver=None, sshkey=None, password=None, paramiko=None,
376 382 timeout=10, cluster_id=None, **extra_args
377 383 ):
378 384 if profile:
379 385 super(Client, self).__init__(debug=debug, profile=profile)
380 386 else:
381 387 super(Client, self).__init__(debug=debug)
382 388 if context is None:
383 389 context = zmq.Context.instance()
384 390 self._context = context
385 391 self._stop_spinning = Event()
386 392
387 393 if 'url_or_file' in extra_args:
388 394 url_file = extra_args['url_or_file']
389 395 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
390 396
391 397 if url_file and util.is_url(url_file):
392 398 raise ValueError("single urls cannot be specified, url-files must be used.")
393 399
394 400 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395 401
396 402 if self._cd is not None:
397 403 if url_file is None:
398 404 if not cluster_id:
399 405 client_json = 'ipcontroller-client.json'
400 406 else:
401 407 client_json = 'ipcontroller-%s-client.json' % cluster_id
402 408 url_file = pjoin(self._cd.security_dir, client_json)
403 409 if url_file is None:
404 410 raise ValueError(
405 411 "I can't find enough information to connect to a hub!"
406 412 " Please specify at least one of url_file or profile."
407 413 )
408 414
409 415 with open(url_file) as f:
410 416 cfg = json.load(f)
411 417
412 418 self._task_scheme = cfg['task_scheme']
413 419
414 420 # sync defaults from args, json:
415 421 if sshserver:
416 422 cfg['ssh'] = sshserver
417 423
418 424 location = cfg.setdefault('location', None)
419 425
420 426 proto,addr = cfg['interface'].split('://')
421 427 addr = util.disambiguate_ip_address(addr, location)
422 428 cfg['interface'] = "%s://%s" % (proto, addr)
423 429
424 430 # turn interface,port into full urls:
425 431 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
426 432 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
427 433
428 434 url = cfg['registration']
429 435
430 436 if location is not None and addr == LOCALHOST:
431 437 # location specified, and connection is expected to be local
432 438 if location not in LOCAL_IPS and not sshserver:
433 439 # load ssh from JSON *only* if the controller is not on
434 440 # this machine
435 441 sshserver=cfg['ssh']
436 442 if location not in LOCAL_IPS and not sshserver:
437 443 # warn if no ssh specified, but SSH is probably needed
438 444 # This is only a warning, because the most likely cause
439 445 # is a local Controller on a laptop whose IP is dynamic
440 446 warnings.warn("""
441 447 Controller appears to be listening on localhost, but not on this machine.
442 448 If this is true, you should specify Client(...,sshserver='you@%s')
443 449 or instruct your controller to listen on an external IP."""%location,
444 450 RuntimeWarning)
445 451 elif not sshserver:
446 452 # otherwise sync with cfg
447 453 sshserver = cfg['ssh']
448 454
449 455 self._config = cfg
450 456
451 457 self._ssh = bool(sshserver or sshkey or password)
452 458 if self._ssh and sshserver is None:
453 459 # default to ssh via localhost
454 460 sshserver = addr
455 461 if self._ssh and password is None:
456 462 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
457 463 password=False
458 464 else:
459 465 password = getpass("SSH Password for %s: "%sshserver)
460 466 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
461 467
462 468 # configure and construct the session
463 469 try:
464 470 extra_args['packer'] = cfg['pack']
465 471 extra_args['unpacker'] = cfg['unpack']
466 472 extra_args['key'] = cast_bytes(cfg['key'])
467 473 extra_args['signature_scheme'] = cfg['signature_scheme']
468 474 except KeyError as exc:
469 475 msg = '\n'.join([
470 476 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
471 477 "If you are reusing connection files, remove them and start ipcontroller again."
472 478 ])
473 479 raise ValueError(msg.format(exc.message))
474 480
475 481 self.session = Session(**extra_args)
476 482
477 483 self._query_socket = self._context.socket(zmq.DEALER)
478 484
479 485 if self._ssh:
480 486 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
481 487 else:
482 488 self._query_socket.connect(cfg['registration'])
483 489
484 490 self.session.debug = self.debug
485 491
486 492 self._notification_handlers = {'registration_notification' : self._register_engine,
487 493 'unregistration_notification' : self._unregister_engine,
488 494 'shutdown_notification' : lambda msg: self.close(),
489 495 }
490 496 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
491 497 'apply_reply' : self._handle_apply_reply}
492 498 self._connect(sshserver, ssh_kwargs, timeout)
493 499
494 500 # last step: setup magics, if we are in IPython:
495 501
496 502 try:
497 503 ip = get_ipython()
498 504 except NameError:
499 505 return
500 506 else:
501 507 if 'px' not in ip.magics_manager.magics:
502 508 # in IPython but we are the first Client.
503 509 # activate a default view for parallel magics.
504 510 self.activate()
505 511
506 512 def __del__(self):
507 513 """cleanup sockets, but _not_ context."""
508 514 self.close()
509 515
510 516 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
511 517 if ipython_dir is None:
512 518 ipython_dir = get_ipython_dir()
513 519 if profile_dir is not None:
514 520 try:
515 521 self._cd = ProfileDir.find_profile_dir(profile_dir)
516 522 return
517 523 except ProfileDirError:
518 524 pass
519 525 elif profile is not None:
520 526 try:
521 527 self._cd = ProfileDir.find_profile_dir_by_name(
522 528 ipython_dir, profile)
523 529 return
524 530 except ProfileDirError:
525 531 pass
526 532 self._cd = None
527 533
528 534 def _update_engines(self, engines):
529 535 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
530 536 for k,v in engines.iteritems():
531 537 eid = int(k)
532 538 if eid not in self._engines:
533 539 self._ids.append(eid)
534 540 self._engines[eid] = v
535 541 self._ids = sorted(self._ids)
536 542 if sorted(self._engines.keys()) != range(len(self._engines)) and \
537 543 self._task_scheme == 'pure' and self._task_socket:
538 544 self._stop_scheduling_tasks()
539 545
540 546 def _stop_scheduling_tasks(self):
541 547 """Stop scheduling tasks because an engine has been unregistered
542 548 from a pure ZMQ scheduler.
543 549 """
544 550 self._task_socket.close()
545 551 self._task_socket = None
546 552 msg = "An engine has been unregistered, and we are using pure " +\
547 553 "ZMQ task scheduling. Task farming will be disabled."
548 554 if self.outstanding:
549 555 msg += " If you were running tasks when this happened, " +\
550 556 "some `outstanding` msg_ids may never resolve."
551 557 warnings.warn(msg, RuntimeWarning)
552 558
553 559 def _build_targets(self, targets):
554 560 """Turn valid target IDs or 'all' into two lists:
555 561 (int_ids, uuids).
556 562 """
557 563 if not self._ids:
558 564 # flush notification socket if no engines yet, just in case
559 565 if not self.ids:
560 566 raise error.NoEnginesRegistered("Can't build targets without any engines")
561 567
562 568 if targets is None:
563 569 targets = self._ids
564 570 elif isinstance(targets, basestring):
565 571 if targets.lower() == 'all':
566 572 targets = self._ids
567 573 else:
568 574 raise TypeError("%r not valid str target, must be 'all'"%(targets))
569 575 elif isinstance(targets, int):
570 576 if targets < 0:
571 577 targets = self.ids[targets]
572 578 if targets not in self._ids:
573 579 raise IndexError("No such engine: %i"%targets)
574 580 targets = [targets]
575 581
576 582 if isinstance(targets, slice):
577 583 indices = range(len(self._ids))[targets]
578 584 ids = self.ids
579 585 targets = [ ids[i] for i in indices ]
580 586
581 587 if not isinstance(targets, (tuple, list, xrange)):
582 588 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
583 589
584 590 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
585 591
586 592 def _connect(self, sshserver, ssh_kwargs, timeout):
587 593 """setup all our socket connections to the cluster. This is called from
588 594 __init__."""
589 595
590 596 # Maybe allow reconnecting?
591 597 if self._connected:
592 598 return
593 599 self._connected=True
594 600
595 601 def connect_socket(s, url):
596 602 # url = util.disambiguate_url(url, self._config['location'])
597 603 if self._ssh:
598 604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
599 605 else:
600 606 return s.connect(url)
601 607
602 608 self.session.send(self._query_socket, 'connection_request')
603 609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
604 610 poller = zmq.Poller()
605 611 poller.register(self._query_socket, zmq.POLLIN)
606 612 # poll expects milliseconds, timeout is seconds
607 613 evts = poller.poll(timeout*1000)
608 614 if not evts:
609 615 raise error.TimeoutError("Hub connection request timed out")
610 616 idents,msg = self.session.recv(self._query_socket,mode=0)
611 617 if self.debug:
612 618 pprint(msg)
613 619 content = msg['content']
614 620 # self._config['registration'] = dict(content)
615 621 cfg = self._config
616 622 if content['status'] == 'ok':
617 623 self._mux_socket = self._context.socket(zmq.DEALER)
618 624 connect_socket(self._mux_socket, cfg['mux'])
619 625
620 626 self._task_socket = self._context.socket(zmq.DEALER)
621 627 connect_socket(self._task_socket, cfg['task'])
622 628
623 629 self._notification_socket = self._context.socket(zmq.SUB)
624 630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
625 631 connect_socket(self._notification_socket, cfg['notification'])
626 632
627 633 self._control_socket = self._context.socket(zmq.DEALER)
628 634 connect_socket(self._control_socket, cfg['control'])
629 635
630 636 self._iopub_socket = self._context.socket(zmq.SUB)
631 637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
632 638 connect_socket(self._iopub_socket, cfg['iopub'])
633 639
634 640 self._update_engines(dict(content['engines']))
635 641 else:
636 642 self._connected = False
637 643 raise Exception("Failed to connect!")
638 644
639 645 #--------------------------------------------------------------------------
640 646 # handlers and callbacks for incoming messages
641 647 #--------------------------------------------------------------------------
642 648
643 649 def _unwrap_exception(self, content):
644 650 """unwrap exception, and remap engine_id to int."""
645 651 e = error.unwrap_exception(content)
646 652 # print e.traceback
647 653 if e.engine_info:
648 654 e_uuid = e.engine_info['engine_uuid']
649 655 eid = self._engines[e_uuid]
650 656 e.engine_info['engine_id'] = eid
651 657 return e
652 658
653 659 def _extract_metadata(self, msg):
654 660 header = msg['header']
655 661 parent = msg['parent_header']
656 662 msg_meta = msg['metadata']
657 663 content = msg['content']
658 664 md = {'msg_id' : parent['msg_id'],
659 665 'received' : datetime.now(),
660 666 'engine_uuid' : msg_meta.get('engine', None),
661 667 'follow' : msg_meta.get('follow', []),
662 668 'after' : msg_meta.get('after', []),
663 669 'status' : content['status'],
664 670 }
665 671
666 672 if md['engine_uuid'] is not None:
667 673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
668 674
669 675 if 'date' in parent:
670 676 md['submitted'] = parent['date']
671 677 if 'started' in msg_meta:
672 678 md['started'] = msg_meta['started']
673 679 if 'date' in header:
674 680 md['completed'] = header['date']
675 681 return md
676 682
677 683 def _register_engine(self, msg):
678 684 """Register a new engine, and update our connection info."""
679 685 content = msg['content']
680 686 eid = content['id']
681 687 d = {eid : content['uuid']}
682 688 self._update_engines(d)
683 689
684 690 def _unregister_engine(self, msg):
685 691 """Unregister an engine that has died."""
686 692 content = msg['content']
687 693 eid = int(content['id'])
688 694 if eid in self._ids:
689 695 self._ids.remove(eid)
690 696 uuid = self._engines.pop(eid)
691 697
692 698 self._handle_stranded_msgs(eid, uuid)
693 699
694 700 if self._task_socket and self._task_scheme == 'pure':
695 701 self._stop_scheduling_tasks()
696 702
697 703 def _handle_stranded_msgs(self, eid, uuid):
698 704 """Handle messages known to be on an engine when the engine unregisters.
699 705
700 706 It is possible that this will fire prematurely - that is, an engine will
701 707 go down after completing a result, and the client will be notified
702 708 of the unregistration and later receive the successful result.
703 709 """
704 710
705 711 outstanding = self._outstanding_dict[uuid]
706 712
707 713 for msg_id in list(outstanding):
708 714 if msg_id in self.results:
709 715 # we already
710 716 continue
711 717 try:
712 718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
713 719 except:
714 720 content = error.wrap_exception()
715 721 # build a fake message:
716 722 msg = self.session.msg('apply_reply', content=content)
717 723 msg['parent_header']['msg_id'] = msg_id
718 724 msg['metadata']['engine'] = uuid
719 725 self._handle_apply_reply(msg)
720 726
721 727 def _handle_execute_reply(self, msg):
722 728 """Save the reply to an execute_request into our results.
723 729
724 730 execute messages are never actually used. apply is used instead.
725 731 """
726 732
727 733 parent = msg['parent_header']
728 734 msg_id = parent['msg_id']
729 735 if msg_id not in self.outstanding:
730 736 if msg_id in self.history:
731 737 print ("got stale result: %s"%msg_id)
732 738 else:
733 739 print ("got unknown result: %s"%msg_id)
734 740 else:
735 741 self.outstanding.remove(msg_id)
736 742
737 743 content = msg['content']
738 744 header = msg['header']
739 745
740 746 # construct metadata:
741 747 md = self.metadata[msg_id]
742 748 md.update(self._extract_metadata(msg))
743 749 # is this redundant?
744 750 self.metadata[msg_id] = md
745 751
746 752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
747 753 if msg_id in e_outstanding:
748 754 e_outstanding.remove(msg_id)
749 755
750 756 # construct result:
751 757 if content['status'] == 'ok':
752 758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
753 759 elif content['status'] == 'aborted':
754 760 self.results[msg_id] = error.TaskAborted(msg_id)
755 761 elif content['status'] == 'resubmitted':
756 762 # TODO: handle resubmission
757 763 pass
758 764 else:
759 765 self.results[msg_id] = self._unwrap_exception(content)
760 766
761 767 def _handle_apply_reply(self, msg):
762 768 """Save the reply to an apply_request into our results."""
763 769 parent = msg['parent_header']
764 770 msg_id = parent['msg_id']
765 771 if msg_id not in self.outstanding:
766 772 if msg_id in self.history:
767 773 print ("got stale result: %s"%msg_id)
768 774 print self.results[msg_id]
769 775 print msg
770 776 else:
771 777 print ("got unknown result: %s"%msg_id)
772 778 else:
773 779 self.outstanding.remove(msg_id)
774 780 content = msg['content']
775 781 header = msg['header']
776 782
777 783 # construct metadata:
778 784 md = self.metadata[msg_id]
779 785 md.update(self._extract_metadata(msg))
780 786 # is this redundant?
781 787 self.metadata[msg_id] = md
782 788
783 789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
784 790 if msg_id in e_outstanding:
785 791 e_outstanding.remove(msg_id)
786 792
787 793 # construct result:
788 794 if content['status'] == 'ok':
789 795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
790 796 elif content['status'] == 'aborted':
791 797 self.results[msg_id] = error.TaskAborted(msg_id)
792 798 elif content['status'] == 'resubmitted':
793 799 # TODO: handle resubmission
794 800 pass
795 801 else:
796 802 self.results[msg_id] = self._unwrap_exception(content)
797 803
798 804 def _flush_notifications(self):
799 805 """Flush notifications of engine registrations waiting
800 806 in ZMQ queue."""
801 807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
802 808 while msg is not None:
803 809 if self.debug:
804 810 pprint(msg)
805 811 msg_type = msg['header']['msg_type']
806 812 handler = self._notification_handlers.get(msg_type, None)
807 813 if handler is None:
808 814 raise Exception("Unhandled message type: %s" % msg_type)
809 815 else:
810 816 handler(msg)
811 817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
812 818
813 819 def _flush_results(self, sock):
814 820 """Flush task or queue results waiting in ZMQ queue."""
815 821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
816 822 while msg is not None:
817 823 if self.debug:
818 824 pprint(msg)
819 825 msg_type = msg['header']['msg_type']
820 826 handler = self._queue_handlers.get(msg_type, None)
821 827 if handler is None:
822 828 raise Exception("Unhandled message type: %s" % msg_type)
823 829 else:
824 830 handler(msg)
825 831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
826 832
827 833 def _flush_control(self, sock):
828 834 """Flush replies from the control channel waiting
829 835 in the ZMQ queue.
830 836
831 837 Currently: ignore them."""
832 838 if self._ignored_control_replies <= 0:
833 839 return
834 840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
835 841 while msg is not None:
836 842 self._ignored_control_replies -= 1
837 843 if self.debug:
838 844 pprint(msg)
839 845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 846
841 847 def _flush_ignored_control(self):
842 848 """flush ignored control replies"""
843 849 while self._ignored_control_replies > 0:
844 850 self.session.recv(self._control_socket)
845 851 self._ignored_control_replies -= 1
846 852
847 853 def _flush_ignored_hub_replies(self):
848 854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
849 855 while msg is not None:
850 856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
851 857
852 858 def _flush_iopub(self, sock):
853 859 """Flush replies from the iopub channel waiting
854 860 in the ZMQ queue.
855 861 """
856 862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
857 863 while msg is not None:
858 864 if self.debug:
859 865 pprint(msg)
860 866 parent = msg['parent_header']
861 867 # ignore IOPub messages with no parent.
862 868 # Caused by print statements or warnings from before the first execution.
863 869 if not parent:
864 870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
865 871 continue
866 872 msg_id = parent['msg_id']
867 873 content = msg['content']
868 874 header = msg['header']
869 875 msg_type = msg['header']['msg_type']
870 876
871 877 # init metadata:
872 878 md = self.metadata[msg_id]
873 879
874 880 if msg_type == 'stream':
875 881 name = content['name']
876 882 s = md[name] or ''
877 883 md[name] = s + content['data']
878 884 elif msg_type == 'pyerr':
879 885 md.update({'pyerr' : self._unwrap_exception(content)})
880 886 elif msg_type == 'pyin':
881 887 md.update({'pyin' : content['code']})
882 888 elif msg_type == 'display_data':
883 889 md['outputs'].append(content)
884 890 elif msg_type == 'pyout':
885 891 md['pyout'] = content
886 892 elif msg_type == 'data_message':
887 893 data, remainder = serialize.unserialize_object(msg['buffers'])
888 894 md['data'].update(data)
889 895 elif msg_type == 'status':
890 896 # idle message comes after all outputs
891 897 if content['execution_state'] == 'idle':
892 898 md['outputs_ready'] = True
893 899 else:
894 900 # unhandled msg_type (status, etc.)
895 901 pass
896 902
897 903 # reduntant?
898 904 self.metadata[msg_id] = md
899 905
900 906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
901 907
902 908 #--------------------------------------------------------------------------
903 909 # len, getitem
904 910 #--------------------------------------------------------------------------
905 911
906 912 def __len__(self):
907 913 """len(client) returns # of engines."""
908 914 return len(self.ids)
909 915
910 916 def __getitem__(self, key):
911 917 """index access returns DirectView multiplexer objects
912 918
913 919 Must be int, slice, or list/tuple/xrange of ints"""
914 920 if not isinstance(key, (int, slice, tuple, list, xrange)):
915 921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
916 922 else:
917 923 return self.direct_view(key)
918 924
919 925 #--------------------------------------------------------------------------
920 926 # Begin public methods
921 927 #--------------------------------------------------------------------------
922 928
923 929 @property
924 930 def ids(self):
925 931 """Always up-to-date ids property."""
926 932 self._flush_notifications()
927 933 # always copy:
928 934 return list(self._ids)
929 935
930 936 def activate(self, targets='all', suffix=''):
931 937 """Create a DirectView and register it with IPython magics
932 938
933 939 Defines the magics `%px, %autopx, %pxresult, %%px`
934 940
935 941 Parameters
936 942 ----------
937 943
938 944 targets: int, list of ints, or 'all'
939 945 The engines on which the view's magics will run
940 946 suffix: str [default: '']
941 947 The suffix, if any, for the magics. This allows you to have
942 948 multiple views associated with parallel magics at the same time.
943 949
944 950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
945 951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
946 952 on engine 0.
947 953 """
948 954 view = self.direct_view(targets)
949 955 view.block = True
950 956 view.activate(suffix)
951 957 return view
952 958
953 959 def close(self):
954 960 if self._closed:
955 961 return
956 962 self.stop_spin_thread()
957 963 snames = filter(lambda n: n.endswith('socket'), dir(self))
958 964 for socket in map(lambda name: getattr(self, name), snames):
959 965 if isinstance(socket, zmq.Socket) and not socket.closed:
960 966 socket.close()
961 967 self._closed = True
962 968
963 969 def _spin_every(self, interval=1):
964 970 """target func for use in spin_thread"""
965 971 while True:
966 972 if self._stop_spinning.is_set():
967 973 return
968 974 time.sleep(interval)
969 975 self.spin()
970 976
971 977 def spin_thread(self, interval=1):
972 978 """call Client.spin() in a background thread on some regular interval
973 979
974 980 This helps ensure that messages don't pile up too much in the zmq queue
975 981 while you are working on other things, or just leaving an idle terminal.
976 982
977 983 It also helps limit potential padding of the `received` timestamp
978 984 on AsyncResult objects, used for timings.
979 985
980 986 Parameters
981 987 ----------
982 988
983 989 interval : float, optional
984 990 The interval on which to spin the client in the background thread
985 991 (simply passed to time.sleep).
986 992
987 993 Notes
988 994 -----
989 995
990 996 For precision timing, you may want to use this method to put a bound
991 997 on the jitter (in seconds) in `received` timestamps used
992 998 in AsyncResult.wall_time.
993 999
994 1000 """
995 1001 if self._spin_thread is not None:
996 1002 self.stop_spin_thread()
997 1003 self._stop_spinning.clear()
998 1004 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
999 1005 self._spin_thread.daemon = True
1000 1006 self._spin_thread.start()
1001 1007
1002 1008 def stop_spin_thread(self):
1003 1009 """stop background spin_thread, if any"""
1004 1010 if self._spin_thread is not None:
1005 1011 self._stop_spinning.set()
1006 1012 self._spin_thread.join()
1007 1013 self._spin_thread = None
1008 1014
1009 1015 def spin(self):
1010 1016 """Flush any registration notifications and execution results
1011 1017 waiting in the ZMQ queue.
1012 1018 """
1013 1019 if self._notification_socket:
1014 1020 self._flush_notifications()
1015 1021 if self._iopub_socket:
1016 1022 self._flush_iopub(self._iopub_socket)
1017 1023 if self._mux_socket:
1018 1024 self._flush_results(self._mux_socket)
1019 1025 if self._task_socket:
1020 1026 self._flush_results(self._task_socket)
1021 1027 if self._control_socket:
1022 1028 self._flush_control(self._control_socket)
1023 1029 if self._query_socket:
1024 1030 self._flush_ignored_hub_replies()
1025 1031
1026 1032 def wait(self, jobs=None, timeout=-1):
1027 1033 """waits on one or more `jobs`, for up to `timeout` seconds.
1028 1034
1029 1035 Parameters
1030 1036 ----------
1031 1037
1032 1038 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1033 1039 ints are indices to self.history
1034 1040 strs are msg_ids
1035 1041 default: wait on all outstanding messages
1036 1042 timeout : float
1037 1043 a time in seconds, after which to give up.
1038 1044 default is -1, which means no timeout
1039 1045
1040 1046 Returns
1041 1047 -------
1042 1048
1043 1049 True : when all msg_ids are done
1044 1050 False : timeout reached, some msg_ids still outstanding
1045 1051 """
1046 1052 tic = time.time()
1047 1053 if jobs is None:
1048 1054 theids = self.outstanding
1049 1055 else:
1050 1056 if isinstance(jobs, (int, basestring, AsyncResult)):
1051 1057 jobs = [jobs]
1052 1058 theids = set()
1053 1059 for job in jobs:
1054 1060 if isinstance(job, int):
1055 1061 # index access
1056 1062 job = self.history[job]
1057 1063 elif isinstance(job, AsyncResult):
1058 1064 map(theids.add, job.msg_ids)
1059 1065 continue
1060 1066 theids.add(job)
1061 1067 if not theids.intersection(self.outstanding):
1062 1068 return True
1063 1069 self.spin()
1064 1070 while theids.intersection(self.outstanding):
1065 1071 if timeout >= 0 and ( time.time()-tic ) > timeout:
1066 1072 break
1067 1073 time.sleep(1e-3)
1068 1074 self.spin()
1069 1075 return len(theids.intersection(self.outstanding)) == 0
1070 1076
1071 1077 #--------------------------------------------------------------------------
1072 1078 # Control methods
1073 1079 #--------------------------------------------------------------------------
1074 1080
1075 1081 @spin_first
1076 1082 def clear(self, targets=None, block=None):
1077 1083 """Clear the namespace in target(s)."""
1078 1084 block = self.block if block is None else block
1079 1085 targets = self._build_targets(targets)[0]
1080 1086 for t in targets:
1081 1087 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1082 1088 error = False
1083 1089 if block:
1084 1090 self._flush_ignored_control()
1085 1091 for i in range(len(targets)):
1086 1092 idents,msg = self.session.recv(self._control_socket,0)
1087 1093 if self.debug:
1088 1094 pprint(msg)
1089 1095 if msg['content']['status'] != 'ok':
1090 1096 error = self._unwrap_exception(msg['content'])
1091 1097 else:
1092 1098 self._ignored_control_replies += len(targets)
1093 1099 if error:
1094 1100 raise error
1095 1101
1096 1102
1097 1103 @spin_first
1098 1104 def abort(self, jobs=None, targets=None, block=None):
1099 1105 """Abort specific jobs from the execution queues of target(s).
1100 1106
1101 1107 This is a mechanism to prevent jobs that have already been submitted
1102 1108 from executing.
1103 1109
1104 1110 Parameters
1105 1111 ----------
1106 1112
1107 1113 jobs : msg_id, list of msg_ids, or AsyncResult
1108 1114 The jobs to be aborted
1109 1115
1110 1116 If unspecified/None: abort all outstanding jobs.
1111 1117
1112 1118 """
1113 1119 block = self.block if block is None else block
1114 1120 jobs = jobs if jobs is not None else list(self.outstanding)
1115 1121 targets = self._build_targets(targets)[0]
1116 1122
1117 1123 msg_ids = []
1118 1124 if isinstance(jobs, (basestring,AsyncResult)):
1119 1125 jobs = [jobs]
1120 1126 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1121 1127 if bad_ids:
1122 1128 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1123 1129 for j in jobs:
1124 1130 if isinstance(j, AsyncResult):
1125 1131 msg_ids.extend(j.msg_ids)
1126 1132 else:
1127 1133 msg_ids.append(j)
1128 1134 content = dict(msg_ids=msg_ids)
1129 1135 for t in targets:
1130 1136 self.session.send(self._control_socket, 'abort_request',
1131 1137 content=content, ident=t)
1132 1138 error = False
1133 1139 if block:
1134 1140 self._flush_ignored_control()
1135 1141 for i in range(len(targets)):
1136 1142 idents,msg = self.session.recv(self._control_socket,0)
1137 1143 if self.debug:
1138 1144 pprint(msg)
1139 1145 if msg['content']['status'] != 'ok':
1140 1146 error = self._unwrap_exception(msg['content'])
1141 1147 else:
1142 1148 self._ignored_control_replies += len(targets)
1143 1149 if error:
1144 1150 raise error
1145 1151
1146 1152 @spin_first
1147 1153 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1148 1154 """Terminates one or more engine processes, optionally including the hub.
1149 1155
1150 1156 Parameters
1151 1157 ----------
1152 1158
1153 1159 targets: list of ints or 'all' [default: all]
1154 1160 Which engines to shutdown.
1155 1161 hub: bool [default: False]
1156 1162 Whether to include the Hub. hub=True implies targets='all'.
1157 1163 block: bool [default: self.block]
1158 1164 Whether to wait for clean shutdown replies or not.
1159 1165 restart: bool [default: False]
1160 1166 NOT IMPLEMENTED
1161 1167 whether to restart engines after shutting them down.
1162 1168 """
1163 1169 from IPython.parallel.error import NoEnginesRegistered
1164 1170 if restart:
1165 1171 raise NotImplementedError("Engine restart is not yet implemented")
1166 1172
1167 1173 block = self.block if block is None else block
1168 1174 if hub:
1169 1175 targets = 'all'
1170 1176 try:
1171 1177 targets = self._build_targets(targets)[0]
1172 1178 except NoEnginesRegistered:
1173 1179 targets = []
1174 1180 for t in targets:
1175 1181 self.session.send(self._control_socket, 'shutdown_request',
1176 1182 content={'restart':restart},ident=t)
1177 1183 error = False
1178 1184 if block or hub:
1179 1185 self._flush_ignored_control()
1180 1186 for i in range(len(targets)):
1181 1187 idents,msg = self.session.recv(self._control_socket, 0)
1182 1188 if self.debug:
1183 1189 pprint(msg)
1184 1190 if msg['content']['status'] != 'ok':
1185 1191 error = self._unwrap_exception(msg['content'])
1186 1192 else:
1187 1193 self._ignored_control_replies += len(targets)
1188 1194
1189 1195 if hub:
1190 1196 time.sleep(0.25)
1191 1197 self.session.send(self._query_socket, 'shutdown_request')
1192 1198 idents,msg = self.session.recv(self._query_socket, 0)
1193 1199 if self.debug:
1194 1200 pprint(msg)
1195 1201 if msg['content']['status'] != 'ok':
1196 1202 error = self._unwrap_exception(msg['content'])
1197 1203
1198 1204 if error:
1199 1205 raise error
1200 1206
1201 1207 #--------------------------------------------------------------------------
1202 1208 # Execution related methods
1203 1209 #--------------------------------------------------------------------------
1204 1210
1205 1211 def _maybe_raise(self, result):
1206 1212 """wrapper for maybe raising an exception if apply failed."""
1207 1213 if isinstance(result, error.RemoteError):
1208 1214 raise result
1209 1215
1210 1216 return result
1211 1217
1212 1218 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1213 1219 ident=None):
1214 1220 """construct and send an apply message via a socket.
1215 1221
1216 1222 This is the principal method with which all engine execution is performed by views.
1217 1223 """
1218 1224
1219 1225 if self._closed:
1220 1226 raise RuntimeError("Client cannot be used after its sockets have been closed")
1221 1227
1222 1228 # defaults:
1223 1229 args = args if args is not None else []
1224 1230 kwargs = kwargs if kwargs is not None else {}
1225 1231 metadata = metadata if metadata is not None else {}
1226 1232
1227 1233 # validate arguments
1228 1234 if not callable(f) and not isinstance(f, Reference):
1229 1235 raise TypeError("f must be callable, not %s"%type(f))
1230 1236 if not isinstance(args, (tuple, list)):
1231 1237 raise TypeError("args must be tuple or list, not %s"%type(args))
1232 1238 if not isinstance(kwargs, dict):
1233 1239 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1234 1240 if not isinstance(metadata, dict):
1235 1241 raise TypeError("metadata must be dict, not %s"%type(metadata))
1236 1242
1237 1243 bufs = serialize.pack_apply_message(f, args, kwargs,
1238 1244 buffer_threshold=self.session.buffer_threshold,
1239 1245 item_threshold=self.session.item_threshold,
1240 1246 )
1241 1247
1242 1248 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1243 1249 metadata=metadata, track=track)
1244 1250
1245 1251 msg_id = msg['header']['msg_id']
1246 1252 self.outstanding.add(msg_id)
1247 1253 if ident:
1248 1254 # possibly routed to a specific engine
1249 1255 if isinstance(ident, list):
1250 1256 ident = ident[-1]
1251 1257 if ident in self._engines.values():
1252 1258 # save for later, in case of engine death
1253 1259 self._outstanding_dict[ident].add(msg_id)
1254 1260 self.history.append(msg_id)
1255 1261 self.metadata[msg_id]['submitted'] = datetime.now()
1256 1262
1257 1263 return msg
1258 1264
1259 1265 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1260 1266 """construct and send an execute request via a socket.
1261 1267
1262 1268 """
1263 1269
1264 1270 if self._closed:
1265 1271 raise RuntimeError("Client cannot be used after its sockets have been closed")
1266 1272
1267 1273 # defaults:
1268 1274 metadata = metadata if metadata is not None else {}
1269 1275
1270 1276 # validate arguments
1271 1277 if not isinstance(code, basestring):
1272 1278 raise TypeError("code must be text, not %s" % type(code))
1273 1279 if not isinstance(metadata, dict):
1274 1280 raise TypeError("metadata must be dict, not %s" % type(metadata))
1275 1281
1276 1282 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1277 1283
1278 1284
1279 1285 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1280 1286 metadata=metadata)
1281 1287
1282 1288 msg_id = msg['header']['msg_id']
1283 1289 self.outstanding.add(msg_id)
1284 1290 if ident:
1285 1291 # possibly routed to a specific engine
1286 1292 if isinstance(ident, list):
1287 1293 ident = ident[-1]
1288 1294 if ident in self._engines.values():
1289 1295 # save for later, in case of engine death
1290 1296 self._outstanding_dict[ident].add(msg_id)
1291 1297 self.history.append(msg_id)
1292 1298 self.metadata[msg_id]['submitted'] = datetime.now()
1293 1299
1294 1300 return msg
1295 1301
1296 1302 #--------------------------------------------------------------------------
1297 1303 # construct a View object
1298 1304 #--------------------------------------------------------------------------
1299 1305
1300 1306 def load_balanced_view(self, targets=None):
1301 1307 """construct a DirectView object.
1302 1308
1303 1309 If no arguments are specified, create a LoadBalancedView
1304 1310 using all engines.
1305 1311
1306 1312 Parameters
1307 1313 ----------
1308 1314
1309 1315 targets: list,slice,int,etc. [default: use all engines]
1310 1316 The subset of engines across which to load-balance
1311 1317 """
1312 1318 if targets == 'all':
1313 1319 targets = None
1314 1320 if targets is not None:
1315 1321 targets = self._build_targets(targets)[1]
1316 1322 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1317 1323
1318 1324 def direct_view(self, targets='all'):
1319 1325 """construct a DirectView object.
1320 1326
1321 1327 If no targets are specified, create a DirectView using all engines.
1322 1328
1323 1329 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1324 1330 evaluate the target engines at each execution, whereas rc[:] will connect to
1325 1331 all *current* engines, and that list will not change.
1326 1332
1327 1333 That is, 'all' will always use all engines, whereas rc[:] will not use
1328 1334 engines added after the DirectView is constructed.
1329 1335
1330 1336 Parameters
1331 1337 ----------
1332 1338
1333 1339 targets: list,slice,int,etc. [default: use all engines]
1334 1340 The engines to use for the View
1335 1341 """
1336 1342 single = isinstance(targets, int)
1337 1343 # allow 'all' to be lazily evaluated at each execution
1338 1344 if targets != 'all':
1339 1345 targets = self._build_targets(targets)[1]
1340 1346 if single:
1341 1347 targets = targets[0]
1342 1348 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1343 1349
1344 1350 #--------------------------------------------------------------------------
1345 1351 # Query methods
1346 1352 #--------------------------------------------------------------------------
1347 1353
1348 1354 @spin_first
1349 1355 def get_result(self, indices_or_msg_ids=None, block=None):
1350 1356 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1351 1357
1352 1358 If the client already has the results, no request to the Hub will be made.
1353 1359
1354 1360 This is a convenient way to construct AsyncResult objects, which are wrappers
1355 1361 that include metadata about execution, and allow for awaiting results that
1356 1362 were not submitted by this Client.
1357 1363
1358 1364 It can also be a convenient way to retrieve the metadata associated with
1359 1365 blocking execution, since it always retrieves
1360 1366
1361 1367 Examples
1362 1368 --------
1363 1369 ::
1364 1370
1365 1371 In [10]: r = client.apply()
1366 1372
1367 1373 Parameters
1368 1374 ----------
1369 1375
1370 1376 indices_or_msg_ids : integer history index, str msg_id, or list of either
1371 1377 The indices or msg_ids of indices to be retrieved
1372 1378
1373 1379 block : bool
1374 1380 Whether to wait for the result to be done
1375 1381
1376 1382 Returns
1377 1383 -------
1378 1384
1379 1385 AsyncResult
1380 1386 A single AsyncResult object will always be returned.
1381 1387
1382 1388 AsyncHubResult
1383 1389 A subclass of AsyncResult that retrieves results from the Hub
1384 1390
1385 1391 """
1386 1392 block = self.block if block is None else block
1387 1393 if indices_or_msg_ids is None:
1388 1394 indices_or_msg_ids = -1
1389 1395
1390 1396 single_result = False
1391 1397 if not isinstance(indices_or_msg_ids, (list,tuple)):
1392 1398 indices_or_msg_ids = [indices_or_msg_ids]
1393 1399 single_result = True
1394 1400
1395 1401 theids = []
1396 1402 for id in indices_or_msg_ids:
1397 1403 if isinstance(id, int):
1398 1404 id = self.history[id]
1399 1405 if not isinstance(id, basestring):
1400 1406 raise TypeError("indices must be str or int, not %r"%id)
1401 1407 theids.append(id)
1402 1408
1403 1409 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1404 1410 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1405 1411
1406 1412 # given single msg_id initially, get_result shot get the result itself,
1407 1413 # not a length-one list
1408 1414 if single_result:
1409 1415 theids = theids[0]
1410 1416
1411 1417 if remote_ids:
1412 1418 ar = AsyncHubResult(self, msg_ids=theids)
1413 1419 else:
1414 1420 ar = AsyncResult(self, msg_ids=theids)
1415 1421
1416 1422 if block:
1417 1423 ar.wait()
1418 1424
1419 1425 return ar
1420 1426
1421 1427 @spin_first
1422 1428 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1423 1429 """Resubmit one or more tasks.
1424 1430
1425 1431 in-flight tasks may not be resubmitted.
1426 1432
1427 1433 Parameters
1428 1434 ----------
1429 1435
1430 1436 indices_or_msg_ids : integer history index, str msg_id, or list of either
1431 1437 The indices or msg_ids of indices to be retrieved
1432 1438
1433 1439 block : bool
1434 1440 Whether to wait for the result to be done
1435 1441
1436 1442 Returns
1437 1443 -------
1438 1444
1439 1445 AsyncHubResult
1440 1446 A subclass of AsyncResult that retrieves results from the Hub
1441 1447
1442 1448 """
1443 1449 block = self.block if block is None else block
1444 1450 if indices_or_msg_ids is None:
1445 1451 indices_or_msg_ids = -1
1446 1452
1447 1453 if not isinstance(indices_or_msg_ids, (list,tuple)):
1448 1454 indices_or_msg_ids = [indices_or_msg_ids]
1449 1455
1450 1456 theids = []
1451 1457 for id in indices_or_msg_ids:
1452 1458 if isinstance(id, int):
1453 1459 id = self.history[id]
1454 1460 if not isinstance(id, basestring):
1455 1461 raise TypeError("indices must be str or int, not %r"%id)
1456 1462 theids.append(id)
1457 1463
1458 1464 content = dict(msg_ids = theids)
1459 1465
1460 1466 self.session.send(self._query_socket, 'resubmit_request', content)
1461 1467
1462 1468 zmq.select([self._query_socket], [], [])
1463 1469 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1464 1470 if self.debug:
1465 1471 pprint(msg)
1466 1472 content = msg['content']
1467 1473 if content['status'] != 'ok':
1468 1474 raise self._unwrap_exception(content)
1469 1475 mapping = content['resubmitted']
1470 1476 new_ids = [ mapping[msg_id] for msg_id in theids ]
1471 1477
1472 1478 ar = AsyncHubResult(self, msg_ids=new_ids)
1473 1479
1474 1480 if block:
1475 1481 ar.wait()
1476 1482
1477 1483 return ar
1478 1484
1479 1485 @spin_first
1480 1486 def result_status(self, msg_ids, status_only=True):
1481 1487 """Check on the status of the result(s) of the apply request with `msg_ids`.
1482 1488
1483 1489 If status_only is False, then the actual results will be retrieved, else
1484 1490 only the status of the results will be checked.
1485 1491
1486 1492 Parameters
1487 1493 ----------
1488 1494
1489 1495 msg_ids : list of msg_ids
1490 1496 if int:
1491 1497 Passed as index to self.history for convenience.
1492 1498 status_only : bool (default: True)
1493 1499 if False:
1494 1500 Retrieve the actual results of completed tasks.
1495 1501
1496 1502 Returns
1497 1503 -------
1498 1504
1499 1505 results : dict
1500 1506 There will always be the keys 'pending' and 'completed', which will
1501 1507 be lists of msg_ids that are incomplete or complete. If `status_only`
1502 1508 is False, then completed results will be keyed by their `msg_id`.
1503 1509 """
1504 1510 if not isinstance(msg_ids, (list,tuple)):
1505 1511 msg_ids = [msg_ids]
1506 1512
1507 1513 theids = []
1508 1514 for msg_id in msg_ids:
1509 1515 if isinstance(msg_id, int):
1510 1516 msg_id = self.history[msg_id]
1511 1517 if not isinstance(msg_id, basestring):
1512 1518 raise TypeError("msg_ids must be str, not %r"%msg_id)
1513 1519 theids.append(msg_id)
1514 1520
1515 1521 completed = []
1516 1522 local_results = {}
1517 1523
1518 1524 # comment this block out to temporarily disable local shortcut:
1519 1525 for msg_id in theids:
1520 1526 if msg_id in self.results:
1521 1527 completed.append(msg_id)
1522 1528 local_results[msg_id] = self.results[msg_id]
1523 1529 theids.remove(msg_id)
1524 1530
1525 1531 if theids: # some not locally cached
1526 1532 content = dict(msg_ids=theids, status_only=status_only)
1527 1533 msg = self.session.send(self._query_socket, "result_request", content=content)
1528 1534 zmq.select([self._query_socket], [], [])
1529 1535 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1530 1536 if self.debug:
1531 1537 pprint(msg)
1532 1538 content = msg['content']
1533 1539 if content['status'] != 'ok':
1534 1540 raise self._unwrap_exception(content)
1535 1541 buffers = msg['buffers']
1536 1542 else:
1537 1543 content = dict(completed=[],pending=[])
1538 1544
1539 1545 content['completed'].extend(completed)
1540 1546
1541 1547 if status_only:
1542 1548 return content
1543 1549
1544 1550 failures = []
1545 1551 # load cached results into result:
1546 1552 content.update(local_results)
1547 1553
1548 1554 # update cache with results:
1549 1555 for msg_id in sorted(theids):
1550 1556 if msg_id in content['completed']:
1551 1557 rec = content[msg_id]
1552 1558 parent = rec['header']
1553 1559 header = rec['result_header']
1554 1560 rcontent = rec['result_content']
1555 1561 iodict = rec['io']
1556 1562 if isinstance(rcontent, str):
1557 1563 rcontent = self.session.unpack(rcontent)
1558 1564
1559 1565 md = self.metadata[msg_id]
1560 1566 md_msg = dict(
1561 1567 content=rcontent,
1562 1568 parent_header=parent,
1563 1569 header=header,
1564 1570 metadata=rec['result_metadata'],
1565 1571 )
1566 1572 md.update(self._extract_metadata(md_msg))
1567 1573 if rec.get('received'):
1568 1574 md['received'] = rec['received']
1569 1575 md.update(iodict)
1570 1576
1571 1577 if rcontent['status'] == 'ok':
1572 1578 if header['msg_type'] == 'apply_reply':
1573 1579 res,buffers = serialize.unserialize_object(buffers)
1574 1580 elif header['msg_type'] == 'execute_reply':
1575 1581 res = ExecuteReply(msg_id, rcontent, md)
1576 1582 else:
1577 1583 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1578 1584 else:
1579 1585 res = self._unwrap_exception(rcontent)
1580 1586 failures.append(res)
1581 1587
1582 1588 self.results[msg_id] = res
1583 1589 content[msg_id] = res
1584 1590
1585 1591 if len(theids) == 1 and failures:
1586 1592 raise failures[0]
1587 1593
1588 1594 error.collect_exceptions(failures, "result_status")
1589 1595 return content
1590 1596
1591 1597 @spin_first
1592 1598 def queue_status(self, targets='all', verbose=False):
1593 1599 """Fetch the status of engine queues.
1594 1600
1595 1601 Parameters
1596 1602 ----------
1597 1603
1598 1604 targets : int/str/list of ints/strs
1599 1605 the engines whose states are to be queried.
1600 1606 default : all
1601 1607 verbose : bool
1602 1608 Whether to return lengths only, or lists of ids for each element
1603 1609 """
1604 1610 if targets == 'all':
1605 1611 # allow 'all' to be evaluated on the engine
1606 1612 engine_ids = None
1607 1613 else:
1608 1614 engine_ids = self._build_targets(targets)[1]
1609 1615 content = dict(targets=engine_ids, verbose=verbose)
1610 1616 self.session.send(self._query_socket, "queue_request", content=content)
1611 1617 idents,msg = self.session.recv(self._query_socket, 0)
1612 1618 if self.debug:
1613 1619 pprint(msg)
1614 1620 content = msg['content']
1615 1621 status = content.pop('status')
1616 1622 if status != 'ok':
1617 1623 raise self._unwrap_exception(content)
1618 1624 content = rekey(content)
1619 1625 if isinstance(targets, int):
1620 1626 return content[targets]
1621 1627 else:
1622 1628 return content
1623 1629
1624 1630 def _build_msgids_from_target(self, targets=None):
1625 1631 """Build a list of msg_ids from the list of engine targets"""
1626 1632 if not targets: # needed as _build_targets otherwise uses all engines
1627 1633 return []
1628 1634 target_ids = self._build_targets(targets)[0]
1629 1635 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1630 1636
1631 1637 def _build_msgids_from_jobs(self, jobs=None):
1632 1638 """Build a list of msg_ids from "jobs" """
1633 1639 if not jobs:
1634 1640 return []
1635 1641 msg_ids = []
1636 1642 if isinstance(jobs, (basestring,AsyncResult)):
1637 1643 jobs = [jobs]
1638 1644 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1639 1645 if bad_ids:
1640 1646 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1641 1647 for j in jobs:
1642 1648 if isinstance(j, AsyncResult):
1643 1649 msg_ids.extend(j.msg_ids)
1644 1650 else:
1645 1651 msg_ids.append(j)
1646 1652 return msg_ids
1647 1653
1648 1654 def purge_local_results(self, jobs=[], targets=[]):
1649 1655 """Clears the client caches of results and frees such memory.
1650 1656
1651 1657 Individual results can be purged by msg_id, or the entire
1652 1658 history of specific targets can be purged.
1653 1659
1654 1660 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1655 1661
1656 1662 The client must have no outstanding tasks before purging the caches.
1657 1663 Raises `AssertionError` if there are still outstanding tasks.
1658 1664
1659 1665 After this call all `AsyncResults` are invalid and should be discarded.
1660 1666
1661 1667 If you must "reget" the results, you can still do so by using
1662 1668 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1663 1669 redownload the results from the hub if they are still available
1664 1670 (i.e `client.purge_hub_results(...)` has not been called.
1665 1671
1666 1672 Parameters
1667 1673 ----------
1668 1674
1669 1675 jobs : str or list of str or AsyncResult objects
1670 1676 the msg_ids whose results should be purged.
1671 1677 targets : int/str/list of ints/strs
1672 1678 The targets, by int_id, whose entire results are to be purged.
1673 1679
1674 1680 default : None
1675 1681 """
1676 1682 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1677 1683
1678 1684 if not targets and not jobs:
1679 1685 raise ValueError("Must specify at least one of `targets` and `jobs`")
1680 1686
1681 1687 if jobs == 'all':
1682 1688 self.results.clear()
1683 1689 self.metadata.clear()
1684 1690 return
1685 1691 else:
1686 1692 msg_ids = []
1687 1693 msg_ids.extend(self._build_msgids_from_target(targets))
1688 1694 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1689 1695 map(self.results.pop, msg_ids)
1690 1696 map(self.metadata.pop, msg_ids)
1691 1697
1692 1698
1693 1699 @spin_first
1694 1700 def purge_hub_results(self, jobs=[], targets=[]):
1695 1701 """Tell the Hub to forget results.
1696 1702
1697 1703 Individual results can be purged by msg_id, or the entire
1698 1704 history of specific targets can be purged.
1699 1705
1700 1706 Use `purge_results('all')` to scrub everything from the Hub's db.
1701 1707
1702 1708 Parameters
1703 1709 ----------
1704 1710
1705 1711 jobs : str or list of str or AsyncResult objects
1706 1712 the msg_ids whose results should be forgotten.
1707 1713 targets : int/str/list of ints/strs
1708 1714 The targets, by int_id, whose entire history is to be purged.
1709 1715
1710 1716 default : None
1711 1717 """
1712 1718 if not targets and not jobs:
1713 1719 raise ValueError("Must specify at least one of `targets` and `jobs`")
1714 1720 if targets:
1715 1721 targets = self._build_targets(targets)[1]
1716 1722
1717 1723 # construct msg_ids from jobs
1718 1724 if jobs == 'all':
1719 1725 msg_ids = jobs
1720 1726 else:
1721 1727 msg_ids = self._build_msgids_from_jobs(jobs)
1722 1728
1723 1729 content = dict(engine_ids=targets, msg_ids=msg_ids)
1724 1730 self.session.send(self._query_socket, "purge_request", content=content)
1725 1731 idents, msg = self.session.recv(self._query_socket, 0)
1726 1732 if self.debug:
1727 1733 pprint(msg)
1728 1734 content = msg['content']
1729 1735 if content['status'] != 'ok':
1730 1736 raise self._unwrap_exception(content)
1731 1737
1732 1738 def purge_results(self, jobs=[], targets=[]):
1733 1739 """Clears the cached results from both the hub and the local client
1734 1740
1735 1741 Individual results can be purged by msg_id, or the entire
1736 1742 history of specific targets can be purged.
1737 1743
1738 1744 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1739 1745 the Client's db.
1740 1746
1741 1747 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1742 1748 the same arguments.
1743 1749
1744 1750 Parameters
1745 1751 ----------
1746 1752
1747 1753 jobs : str or list of str or AsyncResult objects
1748 1754 the msg_ids whose results should be forgotten.
1749 1755 targets : int/str/list of ints/strs
1750 1756 The targets, by int_id, whose entire history is to be purged.
1751 1757
1752 1758 default : None
1753 1759 """
1754 1760 self.purge_local_results(jobs=jobs, targets=targets)
1755 1761 self.purge_hub_results(jobs=jobs, targets=targets)
1756 1762
1757 1763 def purge_everything(self):
1758 1764 """Clears all content from previous Tasks from both the hub and the local client
1759 1765
1760 1766 In addition to calling `purge_results("all")` it also deletes the history and
1761 1767 other bookkeeping lists.
1762 1768 """
1763 1769 self.purge_results("all")
1764 1770 self.history = []
1765 1771 self.session.digest_history.clear()
1766 1772
1767 1773 @spin_first
1768 1774 def hub_history(self):
1769 1775 """Get the Hub's history
1770 1776
1771 1777 Just like the Client, the Hub has a history, which is a list of msg_ids.
1772 1778 This will contain the history of all clients, and, depending on configuration,
1773 1779 may contain history across multiple cluster sessions.
1774 1780
1775 1781 Any msg_id returned here is a valid argument to `get_result`.
1776 1782
1777 1783 Returns
1778 1784 -------
1779 1785
1780 1786 msg_ids : list of strs
1781 1787 list of all msg_ids, ordered by task submission time.
1782 1788 """
1783 1789
1784 1790 self.session.send(self._query_socket, "history_request", content={})
1785 1791 idents, msg = self.session.recv(self._query_socket, 0)
1786 1792
1787 1793 if self.debug:
1788 1794 pprint(msg)
1789 1795 content = msg['content']
1790 1796 if content['status'] != 'ok':
1791 1797 raise self._unwrap_exception(content)
1792 1798 else:
1793 1799 return content['history']
1794 1800
1795 1801 @spin_first
1796 1802 def db_query(self, query, keys=None):
1797 1803 """Query the Hub's TaskRecord database
1798 1804
1799 1805 This will return a list of task record dicts that match `query`
1800 1806
1801 1807 Parameters
1802 1808 ----------
1803 1809
1804 1810 query : mongodb query dict
1805 1811 The search dict. See mongodb query docs for details.
1806 1812 keys : list of strs [optional]
1807 1813 The subset of keys to be returned. The default is to fetch everything but buffers.
1808 1814 'msg_id' will *always* be included.
1809 1815 """
1810 1816 if isinstance(keys, basestring):
1811 1817 keys = [keys]
1812 1818 content = dict(query=query, keys=keys)
1813 1819 self.session.send(self._query_socket, "db_request", content=content)
1814 1820 idents, msg = self.session.recv(self._query_socket, 0)
1815 1821 if self.debug:
1816 1822 pprint(msg)
1817 1823 content = msg['content']
1818 1824 if content['status'] != 'ok':
1819 1825 raise self._unwrap_exception(content)
1820 1826
1821 1827 records = content['records']
1822 1828
1823 1829 buffer_lens = content['buffer_lens']
1824 1830 result_buffer_lens = content['result_buffer_lens']
1825 1831 buffers = msg['buffers']
1826 1832 has_bufs = buffer_lens is not None
1827 1833 has_rbufs = result_buffer_lens is not None
1828 1834 for i,rec in enumerate(records):
1829 1835 # relink buffers
1830 1836 if has_bufs:
1831 1837 blen = buffer_lens[i]
1832 1838 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1833 1839 if has_rbufs:
1834 1840 blen = result_buffer_lens[i]
1835 1841 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1836 1842
1837 1843 return records
1838 1844
1839 1845 __all__ = [ 'Client' ]
@@ -1,801 +1,814 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 import base64
19 20 import sys
20 21 import platform
21 22 import time
22 23 from collections import namedtuple
23 24 from tempfile import mktemp
24 25 from StringIO import StringIO
25 26
26 27 import zmq
27 28 from nose import SkipTest
28 29 from nose.plugins.attrib import attr
29 30
30 31 from IPython.testing import decorators as dec
31 32 from IPython.testing.ipunittest import ParametricTestCase
32 33 from IPython.utils.io import capture_output
33 34
34 35 from IPython import parallel as pmod
35 36 from IPython.parallel import error
36 37 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
37 38 from IPython.parallel import DirectView
38 39 from IPython.parallel.util import interactive
39 40
40 41 from IPython.parallel.tests import add_engines
41 42
42 43 from .clienttest import ClusterTestCase, crash, wait, skip_without
43 44
44 45 def setup():
45 46 add_engines(3, total=True)
46 47
47 48 point = namedtuple("point", "x y")
48 49
49 50 class TestView(ClusterTestCase, ParametricTestCase):
50 51
51 52 def setUp(self):
52 53 # On Win XP, wait for resource cleanup, else parallel test group fails
53 54 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
54 55 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
55 56 time.sleep(2)
56 57 super(TestView, self).setUp()
57 58
58 59 @attr('crash')
59 60 def test_z_crash_mux(self):
60 61 """test graceful handling of engine death (direct)"""
61 62 # self.add_engines(1)
62 63 eid = self.client.ids[-1]
63 64 ar = self.client[eid].apply_async(crash)
64 65 self.assertRaisesRemote(error.EngineError, ar.get, 10)
65 66 eid = ar.engine_id
66 67 tic = time.time()
67 68 while eid in self.client.ids and time.time()-tic < 5:
68 69 time.sleep(.01)
69 70 self.client.spin()
70 71 self.assertFalse(eid in self.client.ids, "Engine should have died")
71 72
72 73 def test_push_pull(self):
73 74 """test pushing and pulling"""
74 75 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
75 76 t = self.client.ids[-1]
76 77 v = self.client[t]
77 78 push = v.push
78 79 pull = v.pull
79 80 v.block=True
80 81 nengines = len(self.client)
81 82 push({'data':data})
82 83 d = pull('data')
83 84 self.assertEqual(d, data)
84 85 self.client[:].push({'data':data})
85 86 d = self.client[:].pull('data', block=True)
86 87 self.assertEqual(d, nengines*[data])
87 88 ar = push({'data':data}, block=False)
88 89 self.assertTrue(isinstance(ar, AsyncResult))
89 90 r = ar.get()
90 91 ar = self.client[:].pull('data', block=False)
91 92 self.assertTrue(isinstance(ar, AsyncResult))
92 93 r = ar.get()
93 94 self.assertEqual(r, nengines*[data])
94 95 self.client[:].push(dict(a=10,b=20))
95 96 r = self.client[:].pull(('a','b'), block=True)
96 97 self.assertEqual(r, nengines*[[10,20]])
97 98
98 99 def test_push_pull_function(self):
99 100 "test pushing and pulling functions"
100 101 def testf(x):
101 102 return 2.0*x
102 103
103 104 t = self.client.ids[-1]
104 105 v = self.client[t]
105 106 v.block=True
106 107 push = v.push
107 108 pull = v.pull
108 109 execute = v.execute
109 110 push({'testf':testf})
110 111 r = pull('testf')
111 112 self.assertEqual(r(1.0), testf(1.0))
112 113 execute('r = testf(10)')
113 114 r = pull('r')
114 115 self.assertEqual(r, testf(10))
115 116 ar = self.client[:].push({'testf':testf}, block=False)
116 117 ar.get()
117 118 ar = self.client[:].pull('testf', block=False)
118 119 rlist = ar.get()
119 120 for r in rlist:
120 121 self.assertEqual(r(1.0), testf(1.0))
121 122 execute("def g(x): return x*x")
122 123 r = pull(('testf','g'))
123 124 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
124 125
125 126 def test_push_function_globals(self):
126 127 """test that pushed functions have access to globals"""
127 128 @interactive
128 129 def geta():
129 130 return a
130 131 # self.add_engines(1)
131 132 v = self.client[-1]
132 133 v.block=True
133 134 v['f'] = geta
134 135 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
135 136 v.execute('a=5')
136 137 v.execute('b=f()')
137 138 self.assertEqual(v['b'], 5)
138 139
139 140 def test_push_function_defaults(self):
140 141 """test that pushed functions preserve default args"""
141 142 def echo(a=10):
142 143 return a
143 144 v = self.client[-1]
144 145 v.block=True
145 146 v['f'] = echo
146 147 v.execute('b=f()')
147 148 self.assertEqual(v['b'], 10)
148 149
149 150 def test_get_result(self):
150 151 """test getting results from the Hub."""
151 152 c = pmod.Client(profile='iptest')
152 153 # self.add_engines(1)
153 154 t = c.ids[-1]
154 155 v = c[t]
155 156 v2 = self.client[t]
156 157 ar = v.apply_async(wait, 1)
157 158 # give the monitor time to notice the message
158 159 time.sleep(.25)
159 160 ahr = v2.get_result(ar.msg_ids[0])
160 161 self.assertTrue(isinstance(ahr, AsyncHubResult))
161 162 self.assertEqual(ahr.get(), ar.get())
162 163 ar2 = v2.get_result(ar.msg_ids[0])
163 164 self.assertFalse(isinstance(ar2, AsyncHubResult))
164 165 c.spin()
165 166 c.close()
166 167
167 168 def test_run_newline(self):
168 169 """test that run appends newline to files"""
169 170 tmpfile = mktemp()
170 171 with open(tmpfile, 'w') as f:
171 172 f.write("""def g():
172 173 return 5
173 174 """)
174 175 v = self.client[-1]
175 176 v.run(tmpfile, block=True)
176 177 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
177 178
178 179 def test_apply_tracked(self):
179 180 """test tracking for apply"""
180 181 # self.add_engines(1)
181 182 t = self.client.ids[-1]
182 183 v = self.client[t]
183 184 v.block=False
184 185 def echo(n=1024*1024, **kwargs):
185 186 with v.temp_flags(**kwargs):
186 187 return v.apply(lambda x: x, 'x'*n)
187 188 ar = echo(1, track=False)
188 189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 190 self.assertTrue(ar.sent)
190 191 ar = echo(track=True)
191 192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 193 self.assertEqual(ar.sent, ar._tracker.done)
193 194 ar._tracker.wait()
194 195 self.assertTrue(ar.sent)
195 196
196 197 def test_push_tracked(self):
197 198 t = self.client.ids[-1]
198 199 ns = dict(x='x'*1024*1024)
199 200 v = self.client[t]
200 201 ar = v.push(ns, block=False, track=False)
201 202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 203 self.assertTrue(ar.sent)
203 204
204 205 ar = v.push(ns, block=False, track=True)
205 206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 207 ar._tracker.wait()
207 208 self.assertEqual(ar.sent, ar._tracker.done)
208 209 self.assertTrue(ar.sent)
209 210 ar.get()
210 211
211 212 def test_scatter_tracked(self):
212 213 t = self.client.ids
213 214 x='x'*1024*1024
214 215 ar = self.client[t].scatter('x', x, block=False, track=False)
215 216 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
216 217 self.assertTrue(ar.sent)
217 218
218 219 ar = self.client[t].scatter('x', x, block=False, track=True)
219 220 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
220 221 self.assertEqual(ar.sent, ar._tracker.done)
221 222 ar._tracker.wait()
222 223 self.assertTrue(ar.sent)
223 224 ar.get()
224 225
225 226 def test_remote_reference(self):
226 227 v = self.client[-1]
227 228 v['a'] = 123
228 229 ra = pmod.Reference('a')
229 230 b = v.apply_sync(lambda x: x, ra)
230 231 self.assertEqual(b, 123)
231 232
232 233
233 234 def test_scatter_gather(self):
234 235 view = self.client[:]
235 236 seq1 = range(16)
236 237 view.scatter('a', seq1)
237 238 seq2 = view.gather('a', block=True)
238 239 self.assertEqual(seq2, seq1)
239 240 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
240 241
241 242 @skip_without('numpy')
242 243 def test_scatter_gather_numpy(self):
243 244 import numpy
244 245 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
245 246 view = self.client[:]
246 247 a = numpy.arange(64)
247 248 view.scatter('a', a, block=True)
248 249 b = view.gather('a', block=True)
249 250 assert_array_equal(b, a)
250 251
251 252 def test_scatter_gather_lazy(self):
252 253 """scatter/gather with targets='all'"""
253 254 view = self.client.direct_view(targets='all')
254 255 x = range(64)
255 256 view.scatter('x', x)
256 257 gathered = view.gather('x', block=True)
257 258 self.assertEqual(gathered, x)
258 259
259 260
260 261 @dec.known_failure_py3
261 262 @skip_without('numpy')
262 263 def test_push_numpy_nocopy(self):
263 264 import numpy
264 265 view = self.client[:]
265 266 a = numpy.arange(64)
266 267 view['A'] = a
267 268 @interactive
268 269 def check_writeable(x):
269 270 return x.flags.writeable
270 271
271 272 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
272 273 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
273 274
274 275 view.push(dict(B=a))
275 276 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
276 277 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
277 278
278 279 @skip_without('numpy')
279 280 def test_apply_numpy(self):
280 281 """view.apply(f, ndarray)"""
281 282 import numpy
282 283 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
283 284
284 285 A = numpy.random.random((100,100))
285 286 view = self.client[-1]
286 287 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
287 288 B = A.astype(dt)
288 289 C = view.apply_sync(lambda x:x, B)
289 290 assert_array_equal(B,C)
290 291
291 292 @skip_without('numpy')
292 293 def test_push_pull_recarray(self):
293 294 """push/pull recarrays"""
294 295 import numpy
295 296 from numpy.testing.utils import assert_array_equal
296 297
297 298 view = self.client[-1]
298 299
299 300 R = numpy.array([
300 301 (1, 'hi', 0.),
301 302 (2**30, 'there', 2.5),
302 303 (-99999, 'world', -12345.6789),
303 304 ], [('n', int), ('s', '|S10'), ('f', float)])
304 305
305 306 view['RR'] = R
306 307 R2 = view['RR']
307 308
308 309 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 310 self.assertEqual(r_dtype, R.dtype)
310 311 self.assertEqual(r_shape, R.shape)
311 312 self.assertEqual(R2.dtype, R.dtype)
312 313 self.assertEqual(R2.shape, R.shape)
313 314 assert_array_equal(R2, R)
314 315
315 316 @skip_without('pandas')
316 317 def test_push_pull_timeseries(self):
317 318 """push/pull pandas.TimeSeries"""
318 319 import pandas
319 320
320 321 ts = pandas.TimeSeries(range(10))
321 322
322 323 view = self.client[-1]
323 324
324 325 view.push(dict(ts=ts), block=True)
325 326 rts = view['ts']
326 327
327 328 self.assertEqual(type(rts), type(ts))
328 329 self.assertTrue((ts == rts).all())
329 330
330 331 def test_map(self):
331 332 view = self.client[:]
332 333 def f(x):
333 334 return x**2
334 335 data = range(16)
335 336 r = view.map_sync(f, data)
336 337 self.assertEqual(r, map(f, data))
337 338
338 339 def test_map_iterable(self):
339 340 """test map on iterables (direct)"""
340 341 view = self.client[:]
341 342 # 101 is prime, so it won't be evenly distributed
342 343 arr = range(101)
343 344 # ensure it will be an iterator, even in Python 3
344 345 it = iter(arr)
345 346 r = view.map_sync(lambda x: x, it)
346 347 self.assertEqual(r, list(arr))
347 348
348 349 @skip_without('numpy')
349 350 def test_map_numpy(self):
350 351 """test map on numpy arrays (direct)"""
351 352 import numpy
352 353 from numpy.testing.utils import assert_array_equal
353 354
354 355 view = self.client[:]
355 356 # 101 is prime, so it won't be evenly distributed
356 357 arr = numpy.arange(101)
357 358 r = view.map_sync(lambda x: x, arr)
358 359 assert_array_equal(r, arr)
359 360
360 361 def test_scatter_gather_nonblocking(self):
361 362 data = range(16)
362 363 view = self.client[:]
363 364 view.scatter('a', data, block=False)
364 365 ar = view.gather('a', block=False)
365 366 self.assertEqual(ar.get(), data)
366 367
367 368 @skip_without('numpy')
368 369 def test_scatter_gather_numpy_nonblocking(self):
369 370 import numpy
370 371 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
371 372 a = numpy.arange(64)
372 373 view = self.client[:]
373 374 ar = view.scatter('a', a, block=False)
374 375 self.assertTrue(isinstance(ar, AsyncResult))
375 376 amr = view.gather('a', block=False)
376 377 self.assertTrue(isinstance(amr, AsyncMapResult))
377 378 assert_array_equal(amr.get(), a)
378 379
379 380 def test_execute(self):
380 381 view = self.client[:]
381 382 # self.client.debug=True
382 383 execute = view.execute
383 384 ar = execute('c=30', block=False)
384 385 self.assertTrue(isinstance(ar, AsyncResult))
385 386 ar = execute('d=[0,1,2]', block=False)
386 387 self.client.wait(ar, 1)
387 388 self.assertEqual(len(ar.get()), len(self.client))
388 389 for c in view['c']:
389 390 self.assertEqual(c, 30)
390 391
391 392 def test_abort(self):
392 393 view = self.client[-1]
393 394 ar = view.execute('import time; time.sleep(1)', block=False)
394 395 ar2 = view.apply_async(lambda : 2)
395 396 ar3 = view.apply_async(lambda : 3)
396 397 view.abort(ar2)
397 398 view.abort(ar3.msg_ids)
398 399 self.assertRaises(error.TaskAborted, ar2.get)
399 400 self.assertRaises(error.TaskAborted, ar3.get)
400 401
401 402 def test_abort_all(self):
402 403 """view.abort() aborts all outstanding tasks"""
403 404 view = self.client[-1]
404 405 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
405 406 view.abort()
406 407 view.wait(timeout=5)
407 408 for ar in ars[5:]:
408 409 self.assertRaises(error.TaskAborted, ar.get)
409 410
410 411 def test_temp_flags(self):
411 412 view = self.client[-1]
412 413 view.block=True
413 414 with view.temp_flags(block=False):
414 415 self.assertFalse(view.block)
415 416 self.assertTrue(view.block)
416 417
417 418 @dec.known_failure_py3
418 419 def test_importer(self):
419 420 view = self.client[-1]
420 421 view.clear(block=True)
421 422 with view.importer:
422 423 import re
423 424
424 425 @interactive
425 426 def findall(pat, s):
426 427 # this globals() step isn't necessary in real code
427 428 # only to prevent a closure in the test
428 429 re = globals()['re']
429 430 return re.findall(pat, s)
430 431
431 432 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
432 433
433 434 def test_unicode_execute(self):
434 435 """test executing unicode strings"""
435 436 v = self.client[-1]
436 437 v.block=True
437 438 if sys.version_info[0] >= 3:
438 439 code="a='é'"
439 440 else:
440 441 code=u"a=u'é'"
441 442 v.execute(code)
442 443 self.assertEqual(v['a'], u'é')
443 444
444 445 def test_unicode_apply_result(self):
445 446 """test unicode apply results"""
446 447 v = self.client[-1]
447 448 r = v.apply_sync(lambda : u'é')
448 449 self.assertEqual(r, u'é')
449 450
450 451 def test_unicode_apply_arg(self):
451 452 """test passing unicode arguments to apply"""
452 453 v = self.client[-1]
453 454
454 455 @interactive
455 456 def check_unicode(a, check):
456 457 assert isinstance(a, unicode), "%r is not unicode"%a
457 458 assert isinstance(check, bytes), "%r is not bytes"%check
458 459 assert a.encode('utf8') == check, "%s != %s"%(a,check)
459 460
460 461 for s in [ u'é', u'ßø®∫',u'asdf' ]:
461 462 try:
462 463 v.apply_sync(check_unicode, s, s.encode('utf8'))
463 464 except error.RemoteError as e:
464 465 if e.ename == 'AssertionError':
465 466 self.fail(e.evalue)
466 467 else:
467 468 raise e
468 469
469 470 def test_map_reference(self):
470 471 """view.map(<Reference>, *seqs) should work"""
471 472 v = self.client[:]
472 473 v.scatter('n', self.client.ids, flatten=True)
473 474 v.execute("f = lambda x,y: x*y")
474 475 rf = pmod.Reference('f')
475 476 nlist = list(range(10))
476 477 mlist = nlist[::-1]
477 478 expected = [ m*n for m,n in zip(mlist, nlist) ]
478 479 result = v.map_sync(rf, mlist, nlist)
479 480 self.assertEqual(result, expected)
480 481
481 482 def test_apply_reference(self):
482 483 """view.apply(<Reference>, *args) should work"""
483 484 v = self.client[:]
484 485 v.scatter('n', self.client.ids, flatten=True)
485 486 v.execute("f = lambda x: n*x")
486 487 rf = pmod.Reference('f')
487 488 result = v.apply_sync(rf, 5)
488 489 expected = [ 5*id for id in self.client.ids ]
489 490 self.assertEqual(result, expected)
490 491
491 492 def test_eval_reference(self):
492 493 v = self.client[self.client.ids[0]]
493 494 v['g'] = range(5)
494 495 rg = pmod.Reference('g[0]')
495 496 echo = lambda x:x
496 497 self.assertEqual(v.apply_sync(echo, rg), 0)
497 498
498 499 def test_reference_nameerror(self):
499 500 v = self.client[self.client.ids[0]]
500 501 r = pmod.Reference('elvis_has_left')
501 502 echo = lambda x:x
502 503 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
503 504
504 505 def test_single_engine_map(self):
505 506 e0 = self.client[self.client.ids[0]]
506 507 r = range(5)
507 508 check = [ -1*i for i in r ]
508 509 result = e0.map_sync(lambda x: -1*x, r)
509 510 self.assertEqual(result, check)
510 511
511 512 def test_len(self):
512 513 """len(view) makes sense"""
513 514 e0 = self.client[self.client.ids[0]]
514 515 yield self.assertEqual(len(e0), 1)
515 516 v = self.client[:]
516 517 yield self.assertEqual(len(v), len(self.client.ids))
517 518 v = self.client.direct_view('all')
518 519 yield self.assertEqual(len(v), len(self.client.ids))
519 520 v = self.client[:2]
520 521 yield self.assertEqual(len(v), 2)
521 522 v = self.client[:1]
522 523 yield self.assertEqual(len(v), 1)
523 524 v = self.client.load_balanced_view()
524 525 yield self.assertEqual(len(v), len(self.client.ids))
525 526 # parametric tests seem to require manual closing?
526 527 self.client.close()
527 528
528 529
529 530 # begin execute tests
530 531
531 532 def test_execute_reply(self):
532 533 e0 = self.client[self.client.ids[0]]
533 534 e0.block = True
534 535 ar = e0.execute("5", silent=False)
535 536 er = ar.get()
536 537 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
537 538 self.assertEqual(er.pyout['data']['text/plain'], '5')
538 539
540 def test_execute_reply_rich(self):
541 e0 = self.client[self.client.ids[0]]
542 e0.block = True
543 e0.execute("from IPython.display import Image, HTML")
544 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
545 er = ar.get()
546 b64data = base64.encodestring(b'garbage').decode('ascii')
547 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
548 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
549 er = ar.get()
550 self.assertEqual(er._repr_html_(), "<b>bold</b>")
551
539 552 def test_execute_reply_stdout(self):
540 553 e0 = self.client[self.client.ids[0]]
541 554 e0.block = True
542 555 ar = e0.execute("print (5)", silent=False)
543 556 er = ar.get()
544 557 self.assertEqual(er.stdout.strip(), '5')
545 558
546 559 def test_execute_pyout(self):
547 560 """execute triggers pyout with silent=False"""
548 561 view = self.client[:]
549 562 ar = view.execute("5", silent=False, block=True)
550 563
551 564 expected = [{'text/plain' : '5'}] * len(view)
552 565 mimes = [ out['data'] for out in ar.pyout ]
553 566 self.assertEqual(mimes, expected)
554 567
555 568 def test_execute_silent(self):
556 569 """execute does not trigger pyout with silent=True"""
557 570 view = self.client[:]
558 571 ar = view.execute("5", block=True)
559 572 expected = [None] * len(view)
560 573 self.assertEqual(ar.pyout, expected)
561 574
562 575 def test_execute_magic(self):
563 576 """execute accepts IPython commands"""
564 577 view = self.client[:]
565 578 view.execute("a = 5")
566 579 ar = view.execute("%whos", block=True)
567 580 # this will raise, if that failed
568 581 ar.get(5)
569 582 for stdout in ar.stdout:
570 583 lines = stdout.splitlines()
571 584 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
572 585 found = False
573 586 for line in lines[2:]:
574 587 split = line.split()
575 588 if split == ['a', 'int', '5']:
576 589 found = True
577 590 break
578 591 self.assertTrue(found, "whos output wrong: %s" % stdout)
579 592
580 593 def test_execute_displaypub(self):
581 594 """execute tracks display_pub output"""
582 595 view = self.client[:]
583 596 view.execute("from IPython.core.display import *")
584 597 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
585 598
586 599 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
587 600 for outputs in ar.outputs:
588 601 mimes = [ out['data'] for out in outputs ]
589 602 self.assertEqual(mimes, expected)
590 603
591 604 def test_apply_displaypub(self):
592 605 """apply tracks display_pub output"""
593 606 view = self.client[:]
594 607 view.execute("from IPython.core.display import *")
595 608
596 609 @interactive
597 610 def publish():
598 611 [ display(i) for i in range(5) ]
599 612
600 613 ar = view.apply_async(publish)
601 614 ar.get(5)
602 615 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
603 616 for outputs in ar.outputs:
604 617 mimes = [ out['data'] for out in outputs ]
605 618 self.assertEqual(mimes, expected)
606 619
607 620 def test_execute_raises(self):
608 621 """exceptions in execute requests raise appropriately"""
609 622 view = self.client[-1]
610 623 ar = view.execute("1/0")
611 624 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
612 625
613 626 def test_remoteerror_render_exception(self):
614 627 """RemoteErrors get nice tracebacks"""
615 628 view = self.client[-1]
616 629 ar = view.execute("1/0")
617 630 ip = get_ipython()
618 631 ip.user_ns['ar'] = ar
619 632 with capture_output() as io:
620 633 ip.run_cell("ar.get(2)")
621 634
622 635 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
623 636
624 637 def test_compositeerror_render_exception(self):
625 638 """CompositeErrors get nice tracebacks"""
626 639 view = self.client[:]
627 640 ar = view.execute("1/0")
628 641 ip = get_ipython()
629 642 ip.user_ns['ar'] = ar
630 643
631 644 with capture_output() as io:
632 645 ip.run_cell("ar.get(2)")
633 646
634 647 count = min(error.CompositeError.tb_limit, len(view))
635 648
636 649 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
637 650 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
638 651 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
639 652
640 653 def test_compositeerror_truncate(self):
641 654 """Truncate CompositeErrors with many exceptions"""
642 655 view = self.client[:]
643 656 msg_ids = []
644 657 for i in range(10):
645 658 ar = view.execute("1/0")
646 659 msg_ids.extend(ar.msg_ids)
647 660
648 661 ar = self.client.get_result(msg_ids)
649 662 try:
650 663 ar.get()
651 664 except error.CompositeError as _e:
652 665 e = _e
653 666 else:
654 667 self.fail("Should have raised CompositeError")
655 668
656 669 lines = e.render_traceback()
657 670 with capture_output() as io:
658 671 e.print_traceback()
659 672
660 673 self.assertTrue("more exceptions" in lines[-1])
661 674 count = e.tb_limit
662 675
663 676 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
664 677 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
665 678 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
666 679
667 680 @dec.skipif_not_matplotlib
668 681 def test_magic_pylab(self):
669 682 """%pylab works on engines"""
670 683 view = self.client[-1]
671 684 ar = view.execute("%pylab inline")
672 685 # at least check if this raised:
673 686 reply = ar.get(5)
674 687 # include imports, in case user config
675 688 ar = view.execute("plot(rand(100))", silent=False)
676 689 reply = ar.get(5)
677 690 self.assertEqual(len(reply.outputs), 1)
678 691 output = reply.outputs[0]
679 692 self.assertTrue("data" in output)
680 693 data = output['data']
681 694 self.assertTrue("image/png" in data)
682 695
683 696 def test_func_default_func(self):
684 697 """interactively defined function as apply func default"""
685 698 def foo():
686 699 return 'foo'
687 700
688 701 def bar(f=foo):
689 702 return f()
690 703
691 704 view = self.client[-1]
692 705 ar = view.apply_async(bar)
693 706 r = ar.get(10)
694 707 self.assertEqual(r, 'foo')
695 708 def test_data_pub_single(self):
696 709 view = self.client[-1]
697 710 ar = view.execute('\n'.join([
698 711 'from IPython.kernel.zmq.datapub import publish_data',
699 712 'for i in range(5):',
700 713 ' publish_data(dict(i=i))'
701 714 ]), block=False)
702 715 self.assertTrue(isinstance(ar.data, dict))
703 716 ar.get(5)
704 717 self.assertEqual(ar.data, dict(i=4))
705 718
706 719 def test_data_pub(self):
707 720 view = self.client[:]
708 721 ar = view.execute('\n'.join([
709 722 'from IPython.kernel.zmq.datapub import publish_data',
710 723 'for i in range(5):',
711 724 ' publish_data(dict(i=i))'
712 725 ]), block=False)
713 726 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
714 727 ar.get(5)
715 728 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
716 729
717 730 def test_can_list_arg(self):
718 731 """args in lists are canned"""
719 732 view = self.client[-1]
720 733 view['a'] = 128
721 734 rA = pmod.Reference('a')
722 735 ar = view.apply_async(lambda x: x, [rA])
723 736 r = ar.get(5)
724 737 self.assertEqual(r, [128])
725 738
726 739 def test_can_dict_arg(self):
727 740 """args in dicts are canned"""
728 741 view = self.client[-1]
729 742 view['a'] = 128
730 743 rA = pmod.Reference('a')
731 744 ar = view.apply_async(lambda x: x, dict(foo=rA))
732 745 r = ar.get(5)
733 746 self.assertEqual(r, dict(foo=128))
734 747
735 748 def test_can_list_kwarg(self):
736 749 """kwargs in lists are canned"""
737 750 view = self.client[-1]
738 751 view['a'] = 128
739 752 rA = pmod.Reference('a')
740 753 ar = view.apply_async(lambda x=5: x, x=[rA])
741 754 r = ar.get(5)
742 755 self.assertEqual(r, [128])
743 756
744 757 def test_can_dict_kwarg(self):
745 758 """kwargs in dicts are canned"""
746 759 view = self.client[-1]
747 760 view['a'] = 128
748 761 rA = pmod.Reference('a')
749 762 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
750 763 r = ar.get(5)
751 764 self.assertEqual(r, dict(foo=128))
752 765
753 766 def test_map_ref(self):
754 767 """view.map works with references"""
755 768 view = self.client[:]
756 769 ranks = sorted(self.client.ids)
757 770 view.scatter('rank', ranks, flatten=True)
758 771 rrank = pmod.Reference('rank')
759 772
760 773 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
761 774 drank = amr.get(5)
762 775 self.assertEqual(drank, [ r*2 for r in ranks ])
763 776
764 777 def test_nested_getitem_setitem(self):
765 778 """get and set with view['a.b']"""
766 779 view = self.client[-1]
767 780 view.execute('\n'.join([
768 781 'class A(object): pass',
769 782 'a = A()',
770 783 'a.b = 128',
771 784 ]), block=True)
772 785 ra = pmod.Reference('a')
773 786
774 787 r = view.apply_sync(lambda x: x.b, ra)
775 788 self.assertEqual(r, 128)
776 789 self.assertEqual(view['a.b'], 128)
777 790
778 791 view['a.b'] = 0
779 792
780 793 r = view.apply_sync(lambda x: x.b, ra)
781 794 self.assertEqual(r, 0)
782 795 self.assertEqual(view['a.b'], 0)
783 796
784 797 def test_return_namedtuple(self):
785 798 def namedtuplify(x, y):
786 799 from IPython.parallel.tests.test_view import point
787 800 return point(x, y)
788 801
789 802 view = self.client[-1]
790 803 p = view.apply_sync(namedtuplify, 1, 2)
791 804 self.assertEqual(p.x, 1)
792 805 self.assertEqual(p.y, 2)
793 806
794 807 def test_apply_namedtuple(self):
795 808 def echoxy(p):
796 809 return p.y, p.x
797 810
798 811 view = self.client[-1]
799 812 tup = view.apply_sync(echoxy, point(1, 2))
800 813 self.assertEqual(tup, (2,1))
801 814
General Comments 0
You need to be logged in to leave comments. Login now