update parallel magics...
MinRK
@@ -1,1659 +1,1694 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 from IPython.core.profiledir import ProfileDir, ProfileDirError
35 36
36 37 from IPython.utils.coloransi import TermColors
37 38 from IPython.utils.jsonutil import rekey
38 39 from IPython.utils.localinterfaces import LOCAL_IPS
39 40 from IPython.utils.path import get_ipython_dir
40 41 from IPython.utils.py3compat import cast_bytes
41 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
42 43 Dict, List, Bool, Set, Any)
43 44 from IPython.external.decorator import decorator
44 45 from IPython.external.ssh import tunnel
45 46
46 47 from IPython.parallel import Reference
47 48 from IPython.parallel import error
48 49 from IPython.parallel import util
49 50
50 51 from IPython.zmq.session import Session, Message
51 52
52 53 from .asyncresult import AsyncResult, AsyncHubResult
53 from IPython.core.profiledir import ProfileDir, ProfileDirError
54 54 from .view import DirectView, LoadBalancedView
55 55
56 56 if sys.version_info[0] >= 3:
57 57 # xrange is used in a couple 'isinstance' tests in py2
58 58 # should be just 'range' in 3k
59 59 xrange = range
60 60
61 61 #--------------------------------------------------------------------------
62 62 # Decorators for Client methods
63 63 #--------------------------------------------------------------------------
64 64
65 65 @decorator
66 66 def spin_first(f, self, *args, **kwargs):
67 67 """Call spin() to sync state prior to calling the method."""
68 68 self.spin()
69 69 return f(self, *args, **kwargs)
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Classes
74 74 #--------------------------------------------------------------------------
75 75
76 76
77 77 class ExecuteReply(object):
78 78 """wrapper for finished Execute results"""
79 79 def __init__(self, msg_id, content, metadata):
80 80 self.msg_id = msg_id
81 81 self._content = content
82 82 self.execution_count = content['execution_count']
83 83 self.metadata = metadata
84 84
85 85 def __getitem__(self, key):
86 86 return self.metadata[key]
87 87
88 88 def __getattr__(self, key):
89 89 if key not in self.metadata:
90 90 raise AttributeError(key)
91 91 return self.metadata[key]
92 92
93 93 def __repr__(self):
94 94 pyout = self.metadata['pyout'] or {'data':{}}
95 95 text_out = pyout['data'].get('text/plain', '')
96 96 if len(text_out) > 32:
97 97 text_out = text_out[:29] + '...'
98 98
99 99 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
100 100
101 101 def _repr_pretty_(self, p, cycle):
102 102 pyout = self.metadata['pyout'] or {'data':{}}
103 103 text_out = pyout['data'].get('text/plain', '')
104 104
105 105 if not text_out:
106 106 return
107 107
108 108 try:
109 109 ip = get_ipython()
110 110 except NameError:
111 111 colors = "NoColor"
112 112 else:
113 113 colors = ip.colors
114 114
115 115 if colors == "NoColor":
116 116 out = normal = ""
117 117 else:
118 118 out = TermColors.Red
119 119 normal = TermColors.Normal
120 120
121 121 if '\n' in text_out and not text_out.startswith('\n'):
122 122 # add newline for multiline reprs
123 123 text_out = '\n' + text_out
124 124
125 125 p.text(
126 126 out + u'Out[%i:%i]: ' % (
127 127 self.metadata['engine_id'], self.execution_count
128 128 ) + normal + text_out
129 129 )
130 130
131 131 def _repr_html_(self):
132 132 pyout = self.metadata['pyout'] or {'data':{}}
133 133 return pyout['data'].get("text/html")
134 134
135 135 def _repr_latex_(self):
136 136 pyout = self.metadata['pyout'] or {'data':{}}
137 137 return pyout['data'].get("text/latex")
138 138
139 139 def _repr_json_(self):
140 140 pyout = self.metadata['pyout'] or {'data':{}}
141 141 return pyout['data'].get("application/json")
142 142
143 143 def _repr_javascript_(self):
144 144 pyout = self.metadata['pyout'] or {'data':{}}
145 145 return pyout['data'].get("application/javascript")
146 146
147 147 def _repr_png_(self):
148 148 pyout = self.metadata['pyout'] or {'data':{}}
149 149 return pyout['data'].get("image/png")
150 150
151 151 def _repr_jpeg_(self):
152 152 pyout = self.metadata['pyout'] or {'data':{}}
153 153 return pyout['data'].get("image/jpeg")
154 154
155 155 def _repr_svg_(self):
156 156 pyout = self.metadata['pyout'] or {'data':{}}
157 157 return pyout['data'].get("image/svg+xml")
158 158
159 159
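# Usage sketch for ExecuteReply (assumes a connected Client `rc` with at
# least one engine; names here are illustrative):
#
#   ar = rc[0].execute('1 + 1')   # returns an AsyncResult
#   reply = ar.get()              # an ExecuteReply once the engine finishes
#   reply.execution_count         # attribute access falls through to metadata
#   reply['engine_id']            # item access reads the same metadata dict
#
# The _repr_*_ methods above just look up the matching mimetype in
# pyout['data'], so IPython's display machinery can render remote output.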
160 160 class Metadata(dict):
161 161 """Subclass of dict for initializing metadata values.
162 162
163 163 Attribute access works on keys.
164 164
165 165 These objects have a strict set of keys - errors will raise if you try
166 166 to add new keys.
167 167 """
168 168 def __init__(self, *args, **kwargs):
169 169 dict.__init__(self)
170 170 md = {'msg_id' : None,
171 171 'submitted' : None,
172 172 'started' : None,
173 173 'completed' : None,
174 174 'received' : None,
175 175 'engine_uuid' : None,
176 176 'engine_id' : None,
177 177 'follow' : None,
178 178 'after' : None,
179 179 'status' : None,
180 180
181 181 'pyin' : None,
182 182 'pyout' : None,
183 183 'pyerr' : None,
184 184 'stdout' : '',
185 185 'stderr' : '',
186 186 'outputs' : [],
187 187 }
188 188 self.update(md)
189 189 self.update(dict(*args, **kwargs))
190 190
191 191 def __getattr__(self, key):
192 192 """getattr aliased to getitem"""
193 193 if key in self.iterkeys():
194 194 return self[key]
195 195 else:
196 196 raise AttributeError(key)
197 197
198 198 def __setattr__(self, key, value):
199 199 """setattr aliased to setitem, with strict"""
200 200 if key in self.iterkeys():
201 201 self[key] = value
202 202 else:
203 203 raise AttributeError(key)
204 204
205 205 def __setitem__(self, key, value):
206 206 """strict static key enforcement"""
207 207 if key in self.iterkeys():
208 208 dict.__setitem__(self, key, value)
209 209 else:
210 210 raise KeyError(key)
211 211
212 212
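# A minimal sketch of the strict-key behavior described in the docstring:
#
#   md = Metadata(status='ok')
#   md.status              # 'ok': attribute access aliases item access
#   md['engine_id'] = 0    # allowed, 'engine_id' is part of the fixed schema
#   md['bogus'] = 1        # raises KeyError: new keys are rejected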
213 213 class Client(HasTraits):
214 214 """A semi-synchronous client to the IPython ZMQ cluster
215 215
216 216 Parameters
217 217 ----------
218 218
219 219 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
220 220 Connection information for the Hub's registration. If a json connector
221 221 file is given, then likely no further configuration is necessary.
222 222 [Default: use profile]
223 223 profile : bytes
224 224 The name of the Cluster profile to be used to find connector information.
225 225 If run from an IPython application, the default profile will be the same
226 226 as the running application, otherwise it will be 'default'.
227 227 context : zmq.Context
228 228 Pass an existing zmq.Context instance, otherwise the client will create its own.
229 229 debug : bool
230 230 flag for lots of message printing for debug purposes
231 231 timeout : int/float
232 232 time (in seconds) to wait for connection replies from the Hub
233 233 [Default: 10]
234 234
235 235 #-------------- session related args ----------------
236 236
237 237 config : Config object
238 238 If specified, this will be relayed to the Session for configuration
239 239 username : str
240 240 set username for the session object
241 241 packer : str (import_string) or callable
242 242 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
243 243 function to serialize messages. Must support same input as
244 244 JSON, and output must be bytes.
245 245 You can pass a callable directly as `pack`
246 246 unpacker : str (import_string) or callable
247 247 The inverse of packer. Only necessary if packer is specified as *not* one
248 248 of 'json' or 'pickle'.
249 249
250 250 #-------------- ssh related args ----------------
251 251 # These are args for configuring the ssh tunnel to be used
252 252 # credentials are used to forward connections over ssh to the Controller
253 253 # Note that the ip given in `addr` needs to be relative to sshserver
254 254 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
255 255 # and set sshserver as the same machine the Controller is on. However,
256 256 # the only requirement is that sshserver is able to see the Controller
257 257 # (i.e. is within the same trusted network).
258 258
259 259 sshserver : str
260 260 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
261 261 If keyfile or password is specified, and this is not, it will default to
262 262 the ip given in addr.
263 263 sshkey : str; path to ssh private key file
264 264 This specifies a key to be used in ssh login, default None.
265 265 Regular default ssh keys will be used without specifying this argument.
266 266 password : str
267 267 Your ssh password to sshserver. Note that if this is left None,
268 268 you will be prompted for it if passwordless key based login is unavailable.
269 269 paramiko : bool
270 270 flag for whether to use paramiko instead of shell ssh for tunneling.
271 271 [default: True on win32, False else]
272 272
273 273 #-------------- exec authentication args ----------------
274 274 If even localhost is untrusted, you can have some protection against
275 275 unauthorized execution by signing messages with HMAC digests.
276 276 Messages are still sent as cleartext, so if someone can snoop your
277 277 loopback traffic this will not protect your privacy, but will prevent
278 278 unauthorized execution.
279 279
280 280 exec_key : str
281 281 an authentication key or file containing a key
282 282 default: None
283 283
284 284
285 285 Attributes
286 286 ----------
287 287
288 288 ids : list of int engine IDs
289 289 requesting the ids attribute always synchronizes
290 290 the registration state. To request ids without synchronization,
291 291 use semi-private _ids attributes.
292 292
293 293 history : list of msg_ids
294 294 a list of msg_ids, keeping track of all the execution
295 295 messages you have submitted in order.
296 296
297 297 outstanding : set of msg_ids
298 298 a set of msg_ids that have been submitted, but whose
299 299 results have not yet been received.
300 300
301 301 results : dict
302 302 a dict of all our results, keyed by msg_id
303 303
304 304 block : bool
305 305 determines default behavior when block not specified
306 306 in execution methods
307 307
308 308 Methods
309 309 -------
310 310
311 311 spin
312 312 flushes incoming results and registration state changes
313 313 control methods spin, and requesting `ids` also ensures the state is up to date
314 314
315 315 wait
316 316 wait on one or more msg_ids
317 317
318 318 execution methods
319 319 apply
320 320 legacy: execute, run
321 321
322 322 data movement
323 323 push, pull, scatter, gather
324 324
325 325 query methods
326 326 queue_status, get_result, purge, result_status
327 327
328 328 control methods
329 329 abort, shutdown
330 330
331 331 """
332 332
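# Typical construction, as a sketch (assumes an ipcontroller has written a
# connection file for the profile in use; the paths/hosts are examples):
#
#   from IPython.parallel import Client
#   rc = Client()                    # default profile's connection file
#   rc = Client(profile='mpi')       # a named profile
#   rc = Client('/path/to/ipcontroller-client.json',
#               sshserver='user@gateway.example.com')  # tunneled connection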
333 333
334 334 block = Bool(False)
335 335 outstanding = Set()
336 336 results = Instance('collections.defaultdict', (dict,))
337 337 metadata = Instance('collections.defaultdict', (Metadata,))
338 338 history = List()
339 339 debug = Bool(False)
340 340 _spin_thread = Any()
341 341 _stop_spinning = Any()
342 342
343 343 profile=Unicode()
344 344 def _profile_default(self):
345 345 if BaseIPythonApplication.initialized():
346 346 # an IPython app *might* be running, try to get its profile
347 347 try:
348 348 return BaseIPythonApplication.instance().profile
349 349 except (AttributeError, MultipleInstanceError):
350 350 # could be a *different* subclass of config.Application,
351 351 # which would raise one of these two errors.
352 352 return u'default'
353 353 else:
354 354 return u'default'
355 355
356 356
357 357 _outstanding_dict = Instance('collections.defaultdict', (set,))
358 358 _ids = List()
359 359 _connected=Bool(False)
360 360 _ssh=Bool(False)
361 361 _context = Instance('zmq.Context')
362 362 _config = Dict()
363 363 _engines=Instance(util.ReverseDict, (), {})
364 364 # _hub_socket=Instance('zmq.Socket')
365 365 _query_socket=Instance('zmq.Socket')
366 366 _control_socket=Instance('zmq.Socket')
367 367 _iopub_socket=Instance('zmq.Socket')
368 368 _notification_socket=Instance('zmq.Socket')
369 369 _mux_socket=Instance('zmq.Socket')
370 370 _task_socket=Instance('zmq.Socket')
371 371 _task_scheme=Unicode()
372 372 _closed = False
373 373 _ignored_control_replies=Integer(0)
374 374 _ignored_hub_replies=Integer(0)
375 375
376 376 def __new__(self, *args, **kw):
377 377 # don't raise on positional args
378 378 return HasTraits.__new__(self, **kw)
379 379
380 380 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
381 381 context=None, debug=False, exec_key=None,
382 382 sshserver=None, sshkey=None, password=None, paramiko=None,
383 383 timeout=10, **extra_args
384 384 ):
385 385 if profile:
386 386 super(Client, self).__init__(debug=debug, profile=profile)
387 387 else:
388 388 super(Client, self).__init__(debug=debug)
389 389 if context is None:
390 390 context = zmq.Context.instance()
391 391 self._context = context
392 392 self._stop_spinning = Event()
393 393
394 394 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
395 395 if self._cd is not None:
396 396 if url_or_file is None:
397 397 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
398 398 if url_or_file is None:
399 399 raise ValueError(
400 400 "I can't find enough information to connect to a hub!"
401 401 " Please specify at least one of url_or_file or profile."
402 402 )
403 403
404 404 if not util.is_url(url_or_file):
405 405 # it's not a url, try for a file
406 406 if not os.path.exists(url_or_file):
407 407 if self._cd:
408 408 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
409 409 if not os.path.exists(url_or_file):
410 410 raise IOError("Connection file not found: %r" % url_or_file)
411 411 with open(url_or_file) as f:
412 412 cfg = json.loads(f.read())
413 413 else:
414 414 cfg = {'url':url_or_file}
415 415
416 416 # sync defaults from args, json:
417 417 if sshserver:
418 418 cfg['ssh'] = sshserver
419 419 if exec_key:
420 420 cfg['exec_key'] = exec_key
421 421 exec_key = cfg['exec_key']
422 422 location = cfg.setdefault('location', None)
423 423 cfg['url'] = util.disambiguate_url(cfg['url'], location)
424 424 url = cfg['url']
425 425 proto,addr,port = util.split_url(url)
426 426 if location is not None and addr == '127.0.0.1':
427 427 # location specified, and connection is expected to be local
428 428 if location not in LOCAL_IPS and not sshserver:
429 429 # load ssh from JSON *only* if the controller is not on
430 430 # this machine
431 431 sshserver=cfg['ssh']
432 432 if location not in LOCAL_IPS and not sshserver:
433 433 # warn if no ssh specified, but SSH is probably needed
434 434 # This is only a warning, because the most likely cause
435 435 # is a local Controller on a laptop whose IP is dynamic
436 436 warnings.warn("""
437 437 Controller appears to be listening on localhost, but not on this machine.
438 438 If this is true, you should specify Client(...,sshserver='you@%s')
439 439 or instruct your controller to listen on an external IP."""%location,
440 440 RuntimeWarning)
441 441 elif not sshserver:
442 442 # otherwise sync with cfg
443 443 sshserver = cfg['ssh']
444 444
445 445 self._config = cfg
446 446
447 447 self._ssh = bool(sshserver or sshkey or password)
448 448 if self._ssh and sshserver is None:
449 449 # default to ssh via localhost
450 450 sshserver = url.split('://')[1].split(':')[0]
451 451 if self._ssh and password is None:
452 452 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
453 453 password=False
454 454 else:
455 455 password = getpass("SSH Password for %s: "%sshserver)
456 456 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
457 457
458 458 # configure and construct the session
459 459 if exec_key is not None:
460 460 if os.path.isfile(exec_key):
461 461 extra_args['keyfile'] = exec_key
462 462 else:
463 463 exec_key = cast_bytes(exec_key)
464 464 extra_args['key'] = exec_key
465 465 self.session = Session(**extra_args)
466 466
467 467 self._query_socket = self._context.socket(zmq.DEALER)
468 468 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
469 469 if self._ssh:
470 470 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
471 471 else:
472 472 self._query_socket.connect(url)
473 473
474 474 self.session.debug = self.debug
475 475
476 476 self._notification_handlers = {'registration_notification' : self._register_engine,
477 477 'unregistration_notification' : self._unregister_engine,
478 478 'shutdown_notification' : lambda msg: self.close(),
479 479 }
480 480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
481 481 'apply_reply' : self._handle_apply_reply}
482 482 self._connect(sshserver, ssh_kwargs, timeout)
483
484 # last step: setup magics, if we are in IPython:
485
486 try:
487 ip = get_ipython()
488 except NameError:
489 return
490 else:
491 if 'px' not in ip.magics_manager.magics:
492 # in IPython but we are the first Client.
493 # activate a default view for parallel magics.
494 self.activate()
483 495
484 496 def __del__(self):
485 497 """cleanup sockets, but _not_ context."""
486 498 self.close()
487 499
488 500 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
489 501 if ipython_dir is None:
490 502 ipython_dir = get_ipython_dir()
491 503 if profile_dir is not None:
492 504 try:
493 505 self._cd = ProfileDir.find_profile_dir(profile_dir)
494 506 return
495 507 except ProfileDirError:
496 508 pass
497 509 elif profile is not None:
498 510 try:
499 511 self._cd = ProfileDir.find_profile_dir_by_name(
500 512 ipython_dir, profile)
501 513 return
502 514 except ProfileDirError:
503 515 pass
504 516 self._cd = None
505 517
506 518 def _update_engines(self, engines):
507 519 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
508 520 for k,v in engines.iteritems():
509 521 eid = int(k)
510 522 self._engines[eid] = v
511 523 self._ids.append(eid)
512 524 self._ids = sorted(self._ids)
513 525 if sorted(self._engines.keys()) != range(len(self._engines)) and \
514 526 self._task_scheme == 'pure' and self._task_socket:
515 527 self._stop_scheduling_tasks()
516 528
517 529 def _stop_scheduling_tasks(self):
518 530 """Stop scheduling tasks because an engine has been unregistered
519 531 from a pure ZMQ scheduler.
520 532 """
521 533 self._task_socket.close()
522 534 self._task_socket = None
523 535 msg = "An engine has been unregistered, and we are using pure " +\
524 536 "ZMQ task scheduling. Task farming will be disabled."
525 537 if self.outstanding:
526 538 msg += " If you were running tasks when this happened, " +\
527 539 "some `outstanding` msg_ids may never resolve."
528 540 warnings.warn(msg, RuntimeWarning)
529 541
530 542 def _build_targets(self, targets):
531 543 """Turn valid target IDs or 'all' into two lists:
532 544 (int_ids, uuids).
533 545 """
534 546 if not self._ids:
535 547 # flush notification socket if no engines yet, just in case
536 548 if not self.ids:
537 549 raise error.NoEnginesRegistered("Can't build targets without any engines")
538 550
539 551 if targets is None:
540 552 targets = self._ids
541 553 elif isinstance(targets, basestring):
542 554 if targets.lower() == 'all':
543 555 targets = self._ids
544 556 else:
545 557 raise TypeError("%r not valid str target, must be 'all'"%(targets))
546 558 elif isinstance(targets, int):
547 559 if targets < 0:
548 560 targets = self.ids[targets]
549 561 if targets not in self._ids:
550 562 raise IndexError("No such engine: %i"%targets)
551 563 targets = [targets]
552 564
553 565 if isinstance(targets, slice):
554 566 indices = range(len(self._ids))[targets]
555 567 ids = self.ids
556 568 targets = [ ids[i] for i in indices ]
557 569
558 570 if not isinstance(targets, (tuple, list, xrange)):
559 571 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
560 572
561 573 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
562 574
563 575 def _connect(self, sshserver, ssh_kwargs, timeout):
564 576 """setup all our socket connections to the cluster. This is called from
565 577 __init__."""
566 578
567 579 # Maybe allow reconnecting?
568 580 if self._connected:
569 581 return
570 582 self._connected=True
571 583
572 584 def connect_socket(s, url):
573 585 url = util.disambiguate_url(url, self._config['location'])
574 586 if self._ssh:
575 587 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
576 588 else:
577 589 return s.connect(url)
578 590
579 591 self.session.send(self._query_socket, 'connection_request')
580 592 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
581 593 poller = zmq.Poller()
582 594 poller.register(self._query_socket, zmq.POLLIN)
583 595 # poll expects milliseconds, timeout is seconds
584 596 evts = poller.poll(timeout*1000)
585 597 if not evts:
586 598 raise error.TimeoutError("Hub connection request timed out")
587 599 idents,msg = self.session.recv(self._query_socket,mode=0)
588 600 if self.debug:
589 601 pprint(msg)
590 602 msg = Message(msg)
591 603 content = msg.content
592 604 self._config['registration'] = dict(content)
593 605 if content.status == 'ok':
594 606 ident = self.session.bsession
595 607 if content.mux:
596 608 self._mux_socket = self._context.socket(zmq.DEALER)
597 609 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
598 610 connect_socket(self._mux_socket, content.mux)
599 611 if content.task:
600 612 self._task_scheme, task_addr = content.task
601 613 self._task_socket = self._context.socket(zmq.DEALER)
602 614 self._task_socket.setsockopt(zmq.IDENTITY, ident)
603 615 connect_socket(self._task_socket, task_addr)
604 616 if content.notification:
605 617 self._notification_socket = self._context.socket(zmq.SUB)
606 618 connect_socket(self._notification_socket, content.notification)
607 619 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
608 620 # if content.query:
609 621 # self._query_socket = self._context.socket(zmq.DEALER)
610 622 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
611 623 # connect_socket(self._query_socket, content.query)
612 624 if content.control:
613 625 self._control_socket = self._context.socket(zmq.DEALER)
614 626 self._control_socket.setsockopt(zmq.IDENTITY, ident)
615 627 connect_socket(self._control_socket, content.control)
616 628 if content.iopub:
617 629 self._iopub_socket = self._context.socket(zmq.SUB)
618 630 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
619 631 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
620 632 connect_socket(self._iopub_socket, content.iopub)
621 633 self._update_engines(dict(content.engines))
622 634 else:
623 635 self._connected = False
624 636 raise Exception("Failed to connect!")
625 637
626 638 #--------------------------------------------------------------------------
627 639 # handlers and callbacks for incoming messages
628 640 #--------------------------------------------------------------------------
629 641
630 642 def _unwrap_exception(self, content):
631 643 """unwrap exception, and remap engine_id to int."""
632 644 e = error.unwrap_exception(content)
633 645 # print e.traceback
634 646 if e.engine_info:
635 647 e_uuid = e.engine_info['engine_uuid']
636 648 eid = self._engines[e_uuid]
637 649 e.engine_info['engine_id'] = eid
638 650 return e
639 651
640 652 def _extract_metadata(self, header, parent, content):
641 653 md = {'msg_id' : parent['msg_id'],
642 654 'received' : datetime.now(),
643 655 'engine_uuid' : header.get('engine', None),
644 656 'follow' : parent.get('follow', []),
645 657 'after' : parent.get('after', []),
646 658 'status' : content['status'],
647 659 }
648 660
649 661 if md['engine_uuid'] is not None:
650 662 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
651 663
652 664 if 'date' in parent:
653 665 md['submitted'] = parent['date']
654 666 if 'started' in header:
655 667 md['started'] = header['started']
656 668 if 'date' in header:
657 669 md['completed'] = header['date']
658 670 return md
659 671
660 672 def _register_engine(self, msg):
661 673 """Register a new engine, and update our connection info."""
662 674 content = msg['content']
663 675 eid = content['id']
664 676 d = {eid : content['queue']}
665 677 self._update_engines(d)
666 678
667 679 def _unregister_engine(self, msg):
668 680 """Unregister an engine that has died."""
669 681 content = msg['content']
670 682 eid = int(content['id'])
671 683 if eid in self._ids:
672 684 self._ids.remove(eid)
673 685 uuid = self._engines.pop(eid)
674 686
675 687 self._handle_stranded_msgs(eid, uuid)
676 688
677 689 if self._task_socket and self._task_scheme == 'pure':
678 690 self._stop_scheduling_tasks()
679 691
680 692 def _handle_stranded_msgs(self, eid, uuid):
681 693 """Handle messages known to be on an engine when the engine unregisters.
682 694
683 695 It is possible that this will fire prematurely - that is, an engine will
684 696 go down after completing a result, and the client will be notified
685 697 of the unregistration and later receive the successful result.
686 698 """
687 699
688 700 outstanding = self._outstanding_dict[uuid]
689 701
690 702 for msg_id in list(outstanding):
691 703 if msg_id in self.results:
692 704 # we already have this result
693 705 continue
694 706 try:
695 707 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
696 708 except:
697 709 content = error.wrap_exception()
698 710 # build a fake message:
699 711 parent = {}
700 712 header = {}
701 713 parent['msg_id'] = msg_id
702 714 header['engine'] = uuid
703 715 header['date'] = datetime.now()
704 716 msg = dict(parent_header=parent, header=header, content=content)
705 717 self._handle_apply_reply(msg)
706 718
707 719 def _handle_execute_reply(self, msg):
708 720 """Save the reply to an execute_request into our results.
709 721
710 722 execute replies are produced by View.execute and the parallel magics.
711 723 """
712 724
713 725 parent = msg['parent_header']
714 726 msg_id = parent['msg_id']
715 727 if msg_id not in self.outstanding:
716 728 if msg_id in self.history:
717 729 print ("got stale result: %s"%msg_id)
718 730 else:
719 731 print ("got unknown result: %s"%msg_id)
720 732 else:
721 733 self.outstanding.remove(msg_id)
722 734
723 735 content = msg['content']
724 736 header = msg['header']
725 737
726 738 # construct metadata:
727 739 md = self.metadata[msg_id]
728 740 md.update(self._extract_metadata(header, parent, content))
729 741 # is this redundant?
730 742 self.metadata[msg_id] = md
731 743
732 744 e_outstanding = self._outstanding_dict[md['engine_uuid']]
733 745 if msg_id in e_outstanding:
734 746 e_outstanding.remove(msg_id)
735 747
736 748 # construct result:
737 749 if content['status'] == 'ok':
738 750 self.results[msg_id] = ExecuteReply(msg_id, content, md)
739 751 elif content['status'] == 'aborted':
740 752 self.results[msg_id] = error.TaskAborted(msg_id)
741 753 elif content['status'] == 'resubmitted':
742 754 # TODO: handle resubmission
743 755 pass
744 756 else:
745 757 self.results[msg_id] = self._unwrap_exception(content)
746 758
747 759 def _handle_apply_reply(self, msg):
748 760 """Save the reply to an apply_request into our results."""
749 761 parent = msg['parent_header']
750 762 msg_id = parent['msg_id']
751 763 if msg_id not in self.outstanding:
752 764 if msg_id in self.history:
753 765 print ("got stale result: %s"%msg_id)
754 766 print self.results[msg_id]
755 767 print msg
756 768 else:
757 769 print ("got unknown result: %s"%msg_id)
758 770 else:
759 771 self.outstanding.remove(msg_id)
760 772 content = msg['content']
761 773 header = msg['header']
762 774
763 775 # construct metadata:
764 776 md = self.metadata[msg_id]
765 777 md.update(self._extract_metadata(header, parent, content))
766 778 # is this redundant?
767 779 self.metadata[msg_id] = md
768 780
769 781 e_outstanding = self._outstanding_dict[md['engine_uuid']]
770 782 if msg_id in e_outstanding:
771 783 e_outstanding.remove(msg_id)
772 784
773 785 # construct result:
774 786 if content['status'] == 'ok':
775 787 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
776 788 elif content['status'] == 'aborted':
777 789 self.results[msg_id] = error.TaskAborted(msg_id)
778 790 elif content['status'] == 'resubmitted':
779 791 # TODO: handle resubmission
780 792 pass
781 793 else:
782 794 self.results[msg_id] = self._unwrap_exception(content)
783 795
784 796 def _flush_notifications(self):
785 797 """Flush notifications of engine registrations waiting
786 798 in ZMQ queue."""
787 799 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
788 800 while msg is not None:
789 801 if self.debug:
790 802 pprint(msg)
791 803 msg_type = msg['header']['msg_type']
792 804 handler = self._notification_handlers.get(msg_type, None)
793 805 if handler is None:
794 806 raise Exception("Unhandled message type: %s"%msg.msg_type)
795 807 else:
796 808 handler(msg)
797 809 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
798 810
799 811 def _flush_results(self, sock):
800 812 """Flush task or queue results waiting in ZMQ queue."""
801 813 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
802 814 while msg is not None:
803 815 if self.debug:
804 816 pprint(msg)
805 817 msg_type = msg['header']['msg_type']
806 818 handler = self._queue_handlers.get(msg_type, None)
807 819 if handler is None:
808 820 raise Exception("Unhandled message type: %s"%msg.msg_type)
809 821 else:
810 822 handler(msg)
811 823 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
812 824
813 825 def _flush_control(self, sock):
814 826 """Flush replies from the control channel waiting
815 827 in the ZMQ queue.
816 828
817 829 Currently: ignore them."""
818 830 if self._ignored_control_replies <= 0:
819 831 return
820 832 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 833 while msg is not None:
822 834 self._ignored_control_replies -= 1
823 835 if self.debug:
824 836 pprint(msg)
825 837 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
826 838
827 839 def _flush_ignored_control(self):
828 840 """flush ignored control replies"""
829 841 while self._ignored_control_replies > 0:
830 842 self.session.recv(self._control_socket)
831 843 self._ignored_control_replies -= 1
832 844
833 845 def _flush_ignored_hub_replies(self):
834 846 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
835 847 while msg is not None:
836 848 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
837 849
838 850 def _flush_iopub(self, sock):
839 851 """Flush replies from the iopub channel waiting
840 852 in the ZMQ queue.
841 853 """
842 854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
843 855 while msg is not None:
844 856 if self.debug:
845 857 pprint(msg)
846 858 parent = msg['parent_header']
847 859 # ignore IOPub messages with no parent.
848 860 # Caused by print statements or warnings from before the first execution.
849 861 if not parent:
# receive the next message before skipping, or this loop would never advance
idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
850 862 continue
851 863 msg_id = parent['msg_id']
852 864 content = msg['content']
853 865 header = msg['header']
854 866 msg_type = msg['header']['msg_type']
855 867
856 868 # init metadata:
857 869 md = self.metadata[msg_id]
858 870
859 871 if msg_type == 'stream':
860 872 name = content['name']
861 873 s = md[name] or ''
862 874 md[name] = s + content['data']
863 875 elif msg_type == 'pyerr':
864 876 md.update({'pyerr' : self._unwrap_exception(content)})
865 877 elif msg_type == 'pyin':
866 878 md.update({'pyin' : content['code']})
867 879 elif msg_type == 'display_data':
868 880 md['outputs'].append(content)
869 881 elif msg_type == 'pyout':
870 882 md['pyout'] = content
871 883 else:
872 884 # unhandled msg_type (status, etc.)
873 885 pass
874 886
875 887 # redundant?
876 888 self.metadata[msg_id] = md
877 889
878 890 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
879 891
880 892 #--------------------------------------------------------------------------
881 893 # len, getitem
882 894 #--------------------------------------------------------------------------
883 895
884 896 def __len__(self):
885 897 """len(client) returns # of engines."""
886 898 return len(self.ids)
887 899
888 900 def __getitem__(self, key):
889 901 """index access returns DirectView multiplexer objects
890 902
891 903 Must be int, slice, or list/tuple/xrange of ints"""
892 904 if not isinstance(key, (int, slice, tuple, list, xrange)):
893 905 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
894 906 else:
895 907 return self.direct_view(key)
896 908
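# Sketch of index access (each form returns a DirectView; assumes engines
# 0..3 are registered):
#
#   rc[0]        # a single engine
#   rc[::2]      # every other engine
#   rc[[0, 2]]   # an explicit list of engine ids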
897 909 #--------------------------------------------------------------------------
898 910 # Begin public methods
899 911 #--------------------------------------------------------------------------
900 912
901 913 @property
902 914 def ids(self):
903 915 """Always up-to-date ids property."""
904 916 self._flush_notifications()
905 917 # always copy:
906 918 return list(self._ids)
907 919
920 def activate(self, targets='all', suffix=''):
921 """Create a DirectView and register it with IPython magics
922
923 Defines the magics `%px, %autopx, %pxresult, %%px`
924
925 Parameters
926 ----------
927
928 targets: int, list of ints, or 'all'
929 The engines on which the view's magics will run
930 suffix: str [default: '']
931 The suffix, if any, for the magics. This allows you to have
932 multiple views associated with parallel magics at the same time.
933
934 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
935 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
936 on engine 0.
937 """
938 view = self.direct_view(targets)
939 view.block = True
940 view.activate(suffix)
941 return view
942
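# Sketch of activate(), mirroring the docstring (assumes a running cluster):
#
#   rc = Client()
#   rc.activate()                       # %px, %%px, %autopx, %pxresult
#   rc.activate(targets=0, suffix='0')  # %px0 etc., bound to engine 0 only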
908 943 def close(self):
909 944 if self._closed:
910 945 return
911 946 self.stop_spin_thread()
912 947 snames = filter(lambda n: n.endswith('socket'), dir(self))
913 948 for socket in map(lambda name: getattr(self, name), snames):
914 949 if isinstance(socket, zmq.Socket) and not socket.closed:
915 950 socket.close()
916 951 self._closed = True
917 952
918 953 def _spin_every(self, interval=1):
919 954 """target func for use in spin_thread"""
920 955 while True:
921 956 if self._stop_spinning.is_set():
922 957 return
923 958 time.sleep(interval)
924 959 self.spin()
925 960
926 961 def spin_thread(self, interval=1):
927 962 """call Client.spin() in a background thread on some regular interval
928 963
929 964 This helps ensure that messages don't pile up too much in the zmq queue
930 965 while you are working on other things, or just leaving an idle terminal.
931 966
932 967 It also helps limit potential padding of the `received` timestamp
933 968 on AsyncResult objects, used for timings.
934 969
935 970 Parameters
936 971 ----------
937 972
938 973 interval : float, optional
939 974 The interval on which to spin the client in the background thread
940 975 (simply passed to time.sleep).
941 976
942 977 Notes
943 978 -----
944 979
945 980 For precision timing, you may want to use this method to put a bound
946 981 on the jitter (in seconds) in `received` timestamps used
947 982 in AsyncResult.wall_time.
948 983
949 984 """
950 985 if self._spin_thread is not None:
951 986 self.stop_spin_thread()
952 987 self._stop_spinning.clear()
953 988 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
954 989 self._spin_thread.daemon = True
955 990 self._spin_thread.start()
956 991
957 992 def stop_spin_thread(self):
958 993 """stop background spin_thread, if any"""
959 994 if self._spin_thread is not None:
960 995 self._stop_spinning.set()
961 996 self._spin_thread.join()
962 997 self._spin_thread = None
963 998
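# Sketch: bound the jitter of `received` timestamps to roughly 10ms while
# idle, then shut the background thread down when finished:
#
#   rc.spin_thread(interval=0.01)
#   ...   # submit work, inspect AsyncResult.wall_time
#   rc.stop_spin_thread()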
964 999 def spin(self):
965 1000 """Flush any registration notifications and execution results
966 1001 waiting in the ZMQ queue.
967 1002 """
968 1003 if self._notification_socket:
969 1004 self._flush_notifications()
970 1005 if self._iopub_socket:
971 1006 self._flush_iopub(self._iopub_socket)
972 1007 if self._mux_socket:
973 1008 self._flush_results(self._mux_socket)
974 1009 if self._task_socket:
975 1010 self._flush_results(self._task_socket)
976 1011 if self._control_socket:
977 1012 self._flush_control(self._control_socket)
978 1013 if self._query_socket:
979 1014 self._flush_ignored_hub_replies()
980 1015
981 1016 def wait(self, jobs=None, timeout=-1):
982 1017 """waits on one or more `jobs`, for up to `timeout` seconds.
983 1018
984 1019 Parameters
985 1020 ----------
986 1021
987 1022 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
988 1023 ints are indices to self.history
989 1024 strs are msg_ids
990 1025 default: wait on all outstanding messages
991 1026 timeout : float
992 1027 a time in seconds, after which to give up.
993 1028 default is -1, which means no timeout
994 1029
995 1030 Returns
996 1031 -------
997 1032
998 1033 True : when all msg_ids are done
999 1034 False : timeout reached, some msg_ids still outstanding
1000 1035 """
1001 1036 tic = time.time()
1002 1037 if jobs is None:
1003 1038 theids = self.outstanding
1004 1039 else:
1005 1040 if isinstance(jobs, (int, basestring, AsyncResult)):
1006 1041 jobs = [jobs]
1007 1042 theids = set()
1008 1043 for job in jobs:
1009 1044 if isinstance(job, int):
1010 1045 # index access
1011 1046 job = self.history[job]
1012 1047 elif isinstance(job, AsyncResult):
1013 1048 map(theids.add, job.msg_ids)
1014 1049 continue
1015 1050 theids.add(job)
1016 1051 if not theids.intersection(self.outstanding):
1017 1052 return True
1018 1053 self.spin()
1019 1054 while theids.intersection(self.outstanding):
1020 1055 if timeout >= 0 and ( time.time()-tic ) > timeout:
1021 1056 break
1022 1057 time.sleep(1e-3)
1023 1058 self.spin()
1024 1059 return len(theids.intersection(self.outstanding)) == 0
1025 1060
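# Sketch of the accepted `jobs` forms (ar1/ar2 are hypothetical AsyncResults
# from earlier submissions):
#
#   rc.wait()                       # all outstanding messages, no timeout
#   rc.wait(rc.history[-1])         # a single msg_id from history
#   done = rc.wait([ar1, ar2], timeout=5.0)   # False if any still pending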
1026 1061 #--------------------------------------------------------------------------
1027 1062 # Control methods
1028 1063 #--------------------------------------------------------------------------
1029 1064
1030 1065 @spin_first
1031 1066 def clear(self, targets=None, block=None):
1032 1067 """Clear the namespace in target(s)."""
1033 1068 block = self.block if block is None else block
1034 1069 targets = self._build_targets(targets)[0]
1035 1070 for t in targets:
1036 1071 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1037 1072 error = False
1038 1073 if block:
1039 1074 self._flush_ignored_control()
1040 1075 for i in range(len(targets)):
1041 1076 idents,msg = self.session.recv(self._control_socket,0)
1042 1077 if self.debug:
1043 1078 pprint(msg)
1044 1079 if msg['content']['status'] != 'ok':
1045 1080 error = self._unwrap_exception(msg['content'])
1046 1081 else:
1047 1082 self._ignored_control_replies += len(targets)
1048 1083 if error:
1049 1084 raise error
1050 1085
1051 1086
1052 1087 @spin_first
1053 1088 def abort(self, jobs=None, targets=None, block=None):
1054 1089 """Abort specific jobs from the execution queues of target(s).
1055 1090
1056 1091 This is a mechanism to prevent jobs that have already been submitted
1057 1092 from executing.
1058 1093
1059 1094 Parameters
1060 1095 ----------
1061 1096
1062 1097 jobs : msg_id, list of msg_ids, or AsyncResult
1063 1098 The jobs to be aborted
1064 1099
1065 1100 If unspecified/None: abort all outstanding jobs.
1066 1101
1067 1102 """
1068 1103 block = self.block if block is None else block
1069 1104 jobs = jobs if jobs is not None else list(self.outstanding)
1070 1105 targets = self._build_targets(targets)[0]
1071 1106
1072 1107 msg_ids = []
1073 1108 if isinstance(jobs, (basestring,AsyncResult)):
1074 1109 jobs = [jobs]
1075 1110 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1076 1111 if bad_ids:
1077 1112 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1078 1113 for j in jobs:
1079 1114 if isinstance(j, AsyncResult):
1080 1115 msg_ids.extend(j.msg_ids)
1081 1116 else:
1082 1117 msg_ids.append(j)
1083 1118 content = dict(msg_ids=msg_ids)
1084 1119 for t in targets:
1085 1120 self.session.send(self._control_socket, 'abort_request',
1086 1121 content=content, ident=t)
1087 1122 error = False
1088 1123 if block:
1089 1124 self._flush_ignored_control()
1090 1125 for i in range(len(targets)):
1091 1126 idents,msg = self.session.recv(self._control_socket,0)
1092 1127 if self.debug:
1093 1128 pprint(msg)
1094 1129 if msg['content']['status'] != 'ok':
1095 1130 error = self._unwrap_exception(msg['content'])
1096 1131 else:
1097 1132 self._ignored_control_replies += len(targets)
1098 1133 if error:
1099 1134 raise error
1100 1135
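# Sketch: abort work that is still queued (already-running tasks cannot be
# stopped this way); `ar` is a hypothetical AsyncResult:
#
#   import time
#   ar = rc[:].apply_async(time.sleep, 60)
#   rc.abort(ar)        # or rc.abort() to abort everything outstanding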
1101 1136 @spin_first
1102 1137 def shutdown(self, targets=None, restart=False, hub=False, block=None):
1103 1138 """Terminates one or more engine processes, optionally including the hub."""
1104 1139 block = self.block if block is None else block
1105 1140 if hub:
1106 1141 targets = 'all'
1107 1142 targets = self._build_targets(targets)[0]
1108 1143 for t in targets:
1109 1144 self.session.send(self._control_socket, 'shutdown_request',
1110 1145 content={'restart':restart},ident=t)
1111 1146 error = False
1112 1147 if block or hub:
1113 1148 self._flush_ignored_control()
1114 1149 for i in range(len(targets)):
1115 1150 idents,msg = self.session.recv(self._control_socket, 0)
1116 1151 if self.debug:
1117 1152 pprint(msg)
1118 1153 if msg['content']['status'] != 'ok':
1119 1154 error = self._unwrap_exception(msg['content'])
1120 1155 else:
1121 1156 self._ignored_control_replies += len(targets)
1122 1157
1123 1158 if hub:
1124 1159 time.sleep(0.25)
1125 1160 self.session.send(self._query_socket, 'shutdown_request')
1126 1161 idents,msg = self.session.recv(self._query_socket, 0)
1127 1162 if self.debug:
1128 1163 pprint(msg)
1129 1164 if msg['content']['status'] != 'ok':
1130 1165 error = self._unwrap_exception(msg['content'])
1131 1166
1132 1167 if error:
1133 1168 raise error
1134 1169
1135 1170 #--------------------------------------------------------------------------
1136 1171 # Execution related methods
1137 1172 #--------------------------------------------------------------------------
1138 1173
1139 1174 def _maybe_raise(self, result):
1140 1175 """wrapper for maybe raising an exception if apply failed."""
1141 1176 if isinstance(result, error.RemoteError):
1142 1177 raise result
1143 1178
1144 1179 return result
1145 1180
1146 1181 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1147 1182 ident=None):
1148 1183 """construct and send an apply message via a socket.
1149 1184
1150 1185 This is the principal method with which all engine execution is performed by views.
1151 1186 """
1152 1187
1153 1188 if self._closed:
1154 1189 raise RuntimeError("Client cannot be used after its sockets have been closed")
1155 1190
1156 1191 # defaults:
1157 1192 args = args if args is not None else []
1158 1193 kwargs = kwargs if kwargs is not None else {}
1159 1194 subheader = subheader if subheader is not None else {}
1160 1195
1161 1196 # validate arguments
1162 1197 if not callable(f) and not isinstance(f, Reference):
1163 1198 raise TypeError("f must be callable, not %s"%type(f))
1164 1199 if not isinstance(args, (tuple, list)):
1165 1200 raise TypeError("args must be tuple or list, not %s"%type(args))
1166 1201 if not isinstance(kwargs, dict):
1167 1202 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1168 1203 if not isinstance(subheader, dict):
1169 1204 raise TypeError("subheader must be dict, not %s"%type(subheader))
1170 1205
1171 1206 bufs = util.pack_apply_message(f,args,kwargs)
1172 1207
1173 1208 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1174 1209 subheader=subheader, track=track)
1175 1210
1176 1211 msg_id = msg['header']['msg_id']
1177 1212 self.outstanding.add(msg_id)
1178 1213 if ident:
1179 1214 # possibly routed to a specific engine
1180 1215 if isinstance(ident, list):
1181 1216 ident = ident[-1]
1182 1217 if ident in self._engines.values():
1183 1218 # save for later, in case of engine death
1184 1219 self._outstanding_dict[ident].add(msg_id)
1185 1220 self.history.append(msg_id)
1186 1221 self.metadata[msg_id]['submitted'] = datetime.now()
1187 1222
1188 1223 return msg
1189 1224
1190 1225 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1191 1226 """construct and send an execute request via a socket.
1192 1227
1193 1228 """
1194 1229
1195 1230 if self._closed:
1196 1231 raise RuntimeError("Client cannot be used after its sockets have been closed")
1197 1232
1198 1233 # defaults:
1199 1234 subheader = subheader if subheader is not None else {}
1200 1235
1201 1236 # validate arguments
1202 1237 if not isinstance(code, basestring):
1203 1238 raise TypeError("code must be text, not %s" % type(code))
1204 1239 if not isinstance(subheader, dict):
1205 1240 raise TypeError("subheader must be dict, not %s" % type(subheader))
1206 1241
1207 1242 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1208 1243
1209 1244
1210 1245 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1211 1246 subheader=subheader)
1212 1247
1213 1248 msg_id = msg['header']['msg_id']
1214 1249 self.outstanding.add(msg_id)
1215 1250 if ident:
1216 1251 # possibly routed to a specific engine
1217 1252 if isinstance(ident, list):
1218 1253 ident = ident[-1]
1219 1254 if ident in self._engines.values():
1220 1255 # save for later, in case of engine death
1221 1256 self._outstanding_dict[ident].add(msg_id)
1222 1257 self.history.append(msg_id)
1223 1258 self.metadata[msg_id]['submitted'] = datetime.now()
1224 1259
1225 1260 return msg
1226 1261
1227 1262 #--------------------------------------------------------------------------
1228 1263 # construct a View object
1229 1264 #--------------------------------------------------------------------------
1230 1265
1231 1266 def load_balanced_view(self, targets=None):
1232 1267 """construct a LoadBalancedView object.
1233 1268
1234 1269 If no arguments are specified, create a LoadBalancedView
1235 1270 using all engines.
1236 1271
1237 1272 Parameters
1238 1273 ----------
1239 1274
1240 1275 targets: list,slice,int,etc. [default: use all engines]
1241 1276 The subset of engines across which to load-balance
1242 1277 """
1243 1278 if targets == 'all':
1244 1279 targets = None
1245 1280 if targets is not None:
1246 1281 targets = self._build_targets(targets)[1]
1247 1282 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1248 1283
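# Sketch: load-balance across all engines, or an explicit subset:
#
#   lview = rc.load_balanced_view()          # all engines
#   lview = rc.load_balanced_view([0, 1])    # only engines 0 and 1
#   ar = lview.apply_async(pow, 2, 10)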
1249 1284 def direct_view(self, targets='all'):
1250 1285 """construct a DirectView object.
1251 1286
1252 1287 If no targets are specified, create a DirectView using all engines.
1253 1288
1254 1289 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1255 1290 evaluate the target engines at each execution, whereas rc[:] will connect to
1256 1291 all *current* engines, and that list will not change.
1257 1292
1258 1293 That is, 'all' will always use all engines, whereas rc[:] will not use
1259 1294 engines added after the DirectView is constructed.
1260 1295
1261 1296 Parameters
1262 1297 ----------
1263 1298
1264 1299 targets: list,slice,int,etc. [default: use all engines]
1265 1300 The engines to use for the View
1266 1301 """
1267 1302 single = isinstance(targets, int)
1268 1303 # allow 'all' to be lazily evaluated at each execution
1269 1304 if targets != 'all':
1270 1305 targets = self._build_targets(targets)[1]
1271 1306 if single:
1272 1307 targets = targets[0]
1273 1308 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1274 1309
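# Sketch of the lazy-vs-frozen distinction described above:
#
#   dv_all = rc.direct_view('all')   # re-evaluates the engine list each run
#   dv_now = rc[:]                   # frozen to the engines registered now
#   dv0    = rc.direct_view(0)       # single-engine view (int target)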
1275 1310 #--------------------------------------------------------------------------
1276 1311 # Query methods
1277 1312 #--------------------------------------------------------------------------
1278 1313
1279 1314 @spin_first
1280 1315 def get_result(self, indices_or_msg_ids=None, block=None):
1281 1316 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1282 1317
1283 1318 If the client already has the results, no request to the Hub will be made.
1284 1319
1285 1320 This is a convenient way to construct AsyncResult objects, which are wrappers
1286 1321 that include metadata about execution, and allow for awaiting results that
1287 1322 were not submitted by this Client.
1288 1323
1289 1324 It can also be a convenient way to retrieve the metadata associated with
1290 1325 blocking execution, since it always returns an AsyncResult, which carries that metadata.
1291 1326
1292 1327 Examples
1293 1328 --------
1294 1329 ::
1295 1330
1296 1331 In [10]: ar = client.get_result(msg_id)
1297 1332
1298 1333 Parameters
1299 1334 ----------
1300 1335
1301 1336 indices_or_msg_ids : integer history index, str msg_id, or list of either
1302 1337 The history indices or msg_ids of the results to be retrieved
1303 1338
1304 1339 block : bool
1305 1340 Whether to wait for the result to be done
1306 1341
1307 1342 Returns
1308 1343 -------
1309 1344
1310 1345 AsyncResult
1311 1346 A single AsyncResult object will always be returned.
1312 1347
1313 1348 AsyncHubResult
1314 1349 A subclass of AsyncResult that retrieves results from the Hub
1315 1350
1316 1351 """
1317 1352 block = self.block if block is None else block
1318 1353 if indices_or_msg_ids is None:
1319 1354 indices_or_msg_ids = -1
1320 1355
1321 1356 if not isinstance(indices_or_msg_ids, (list,tuple)):
1322 1357 indices_or_msg_ids = [indices_or_msg_ids]
1323 1358
1324 1359 theids = []
1325 1360 for id in indices_or_msg_ids:
1326 1361 if isinstance(id, int):
1327 1362 id = self.history[id]
1328 1363 if not isinstance(id, basestring):
1329 1364 raise TypeError("indices must be str or int, not %r"%id)
1330 1365 theids.append(id)
1331 1366
1332 1367 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1333 1368 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1334 1369
1335 1370 if remote_ids:
1336 1371 ar = AsyncHubResult(self, msg_ids=theids)
1337 1372 else:
1338 1373 ar = AsyncResult(self, msg_ids=theids)
1339 1374
1340 1375 if block:
1341 1376 ar.wait()
1342 1377
1343 1378 return ar
1344 1379
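# Sketch: reconstruct an AsyncResult for earlier or foreign work (`msg_id`
# is any id from rc.history or hub_history()):
#
#   ar = rc.get_result(-1)        # most recent submission in history
#   ar = rc.get_result(msg_id)    # even tasks submitted by another client
#   ar.wall_time                  # timing metadata, once the result is done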
1345 1380 @spin_first
1346 1381 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1347 1382 """Resubmit one or more tasks.
1348 1383
1349 1384 in-flight tasks may not be resubmitted.
1350 1385
1351 1386 Parameters
1352 1387 ----------
1353 1388
1354 1389 indices_or_msg_ids : integer history index, str msg_id, or list of either
1355 1390 The history indices or msg_ids of the tasks to be resubmitted
1356 1391
1357 1392 block : bool
1358 1393 Whether to wait for the result to be done
1359 1394
1360 1395 Returns
1361 1396 -------
1362 1397
1363 1398 AsyncHubResult
1364 1399 A subclass of AsyncResult that retrieves results from the Hub
1365 1400
1366 1401 """
1367 1402 block = self.block if block is None else block
1368 1403 if indices_or_msg_ids is None:
1369 1404 indices_or_msg_ids = -1
1370 1405
1371 1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1372 1407 indices_or_msg_ids = [indices_or_msg_ids]
1373 1408
1374 1409 theids = []
1375 1410 for id in indices_or_msg_ids:
1376 1411 if isinstance(id, int):
1377 1412 id = self.history[id]
1378 1413 if not isinstance(id, basestring):
1379 1414 raise TypeError("indices must be str or int, not %r"%id)
1380 1415 theids.append(id)
1381 1416
1382 1417 content = dict(msg_ids = theids)
1383 1418
1384 1419 self.session.send(self._query_socket, 'resubmit_request', content)
1385 1420
1386 1421 zmq.select([self._query_socket], [], [])
1387 1422 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1388 1423 if self.debug:
1389 1424 pprint(msg)
1390 1425 content = msg['content']
1391 1426 if content['status'] != 'ok':
1392 1427 raise self._unwrap_exception(content)
1393 1428 mapping = content['resubmitted']
1394 1429 new_ids = [ mapping[msg_id] for msg_id in theids ]
1395 1430
1396 1431 ar = AsyncHubResult(self, msg_ids=new_ids)
1397 1432
1398 1433 if block:
1399 1434 ar.wait()
1400 1435
1401 1436 return ar
1402 1437
1403 1438 @spin_first
1404 1439 def result_status(self, msg_ids, status_only=True):
1405 1440 """Check on the status of the result(s) of the apply request with `msg_ids`.
1406 1441
1407 1442 If status_only is False, then the actual results will be retrieved, else
1408 1443 only the status of the results will be checked.
1409 1444
1410 1445 Parameters
1411 1446 ----------
1412 1447
1413 1448 msg_ids : list of msg_ids
1414 1449 if int:
1415 1450 Passed as index to self.history for convenience.
1416 1451 status_only : bool (default: True)
1417 1452 if False:
1418 1453 Retrieve the actual results of completed tasks.
1419 1454
1420 1455 Returns
1421 1456 -------
1422 1457
1423 1458 results : dict
1424 1459 There will always be the keys 'pending' and 'completed', which will
1425 1460 be lists of msg_ids that are incomplete or complete. If `status_only`
1426 1461 is False, then completed results will be keyed by their `msg_id`.
1427 1462 """
1428 1463 if not isinstance(msg_ids, (list,tuple)):
1429 1464 msg_ids = [msg_ids]
1430 1465
1431 1466 theids = []
1432 1467 for msg_id in msg_ids:
1433 1468 if isinstance(msg_id, int):
1434 1469 msg_id = self.history[msg_id]
1435 1470 if not isinstance(msg_id, basestring):
1436 1471 raise TypeError("msg_ids must be str, not %r"%msg_id)
1437 1472 theids.append(msg_id)
1438 1473
1439 1474 completed = []
1440 1475 local_results = {}
1441 1476
1442 1477 # comment this block out to temporarily disable local shortcut:
1443 1478 for msg_id in theids:
1444 1479 if msg_id in self.results:
1445 1480 completed.append(msg_id)
1446 1481 local_results[msg_id] = self.results[msg_id]
1447 1482 theids.remove(msg_id)
1448 1483
1449 1484 if theids: # some not locally cached
1450 1485 content = dict(msg_ids=theids, status_only=status_only)
1451 1486 msg = self.session.send(self._query_socket, "result_request", content=content)
1452 1487 zmq.select([self._query_socket], [], [])
1453 1488 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1454 1489 if self.debug:
1455 1490 pprint(msg)
1456 1491 content = msg['content']
1457 1492 if content['status'] != 'ok':
1458 1493 raise self._unwrap_exception(content)
1459 1494 buffers = msg['buffers']
1460 1495 else:
1461 1496 content = dict(completed=[],pending=[])
1462 1497
1463 1498 content['completed'].extend(completed)
1464 1499
1465 1500 if status_only:
1466 1501 return content
1467 1502
1468 1503 failures = []
1469 1504 # load cached results into result:
1470 1505 content.update(local_results)
1471 1506
1472 1507 # update cache with results:
1473 1508 for msg_id in sorted(theids):
1474 1509 if msg_id in content['completed']:
1475 1510 rec = content[msg_id]
1476 1511 parent = rec['header']
1477 1512 header = rec['result_header']
1478 1513 rcontent = rec['result_content']
1479 1514 iodict = rec['io']
1480 1515 if isinstance(rcontent, str):
1481 1516 rcontent = self.session.unpack(rcontent)
1482 1517
1483 1518 md = self.metadata[msg_id]
1484 1519 md.update(self._extract_metadata(header, parent, rcontent))
1485 1520 if rec.get('received'):
1486 1521 md['received'] = rec['received']
1487 1522 md.update(iodict)
1488 1523
1489 1524 if rcontent['status'] == 'ok':
1490 1525 res,buffers = util.unserialize_object(buffers)
1491 1526 else:
1492 1527 print rcontent
1493 1528 res = self._unwrap_exception(rcontent)
1494 1529 failures.append(res)
1495 1530
1496 1531 self.results[msg_id] = res
1497 1532 content[msg_id] = res
1498 1533
1499 1534 if len(theids) == 1 and failures:
1500 1535 raise failures[0]
1501 1536
1502 1537 error.collect_exceptions(failures, "result_status")
1503 1538 return content
1504 1539
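# A minimal usage sketch for result_status, assuming a running cluster
# and a connected Client `rc` (both hypothetical here):
#
#   import os
#   from IPython.parallel import Client
#   rc = Client()
#   ar = rc[:].apply_async(os.getpid)
#   status = rc.result_status(ar.msg_ids)               # status_only=True
#   print status['pending'], status['completed']
#   full = rc.result_status(ar.msg_ids, status_only=False)
#   for msg_id in full['completed']:                    # results keyed by msg_id
#       print msg_id, full[msg_id]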
1505 1540 @spin_first
1506 1541 def queue_status(self, targets='all', verbose=False):
1507 1542 """Fetch the status of engine queues.
1508 1543
1509 1544 Parameters
1510 1545 ----------
1511 1546
1512 1547 targets : int/str/list of ints/strs
1513 1548 the engines whose states are to be queried.
1514 1549 default : all
1515 1550 verbose : bool
1516 1551 Whether to return lengths only, or lists of ids for each element
1517 1552 """
1518 1553 if targets == 'all':
1519 1554 # allow 'all' to be evaluated on the engine
1520 1555 engine_ids = None
1521 1556 else:
1522 1557 engine_ids = self._build_targets(targets)[1]
1523 1558 content = dict(targets=engine_ids, verbose=verbose)
1524 1559 self.session.send(self._query_socket, "queue_request", content=content)
1525 1560 idents,msg = self.session.recv(self._query_socket, 0)
1526 1561 if self.debug:
1527 1562 pprint(msg)
1528 1563 content = msg['content']
1529 1564 status = content.pop('status')
1530 1565 if status != 'ok':
1531 1566 raise self._unwrap_exception(content)
1532 1567 content = rekey(content)
1533 1568 if isinstance(targets, int):
1534 1569 return content[targets]
1535 1570 else:
1536 1571 return content
1537 1572
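# Usage sketch for queue_status (same hypothetical connected Client `rc`):
#
#   rc.queue_status()              # dict keyed by engine id, e.g. {0: {'queue': 0, ...}}
#   rc.queue_status(0)             # a single engine: just that engine's dict
#   rc.queue_status(verbose=True)  # lists of msg_ids instead of lengths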
1538 1573 @spin_first
1539 1574 def purge_results(self, jobs=[], targets=[]):
1540 1575 """Tell the Hub to forget results.
1541 1576
1542 1577 Individual results can be purged by msg_id, or the entire
1543 1578 history of specific targets can be purged.
1544 1579
1545 1580 Use `purge_results('all')` to scrub everything from the Hub's db.
1546 1581
1547 1582 Parameters
1548 1583 ----------
1549 1584
1550 1585 jobs : str or list of str or AsyncResult objects
1551 1586 the msg_ids whose results should be forgotten.
1552 1587 targets : int/str/list of ints/strs
1553 1588 The targets, by int_id, whose entire history is to be purged.
1554 1589
1555 1590 default : None
1556 1591 """
1557 1592 if not targets and not jobs:
1558 1593 raise ValueError("Must specify at least one of `targets` and `jobs`")
1559 1594 if targets:
1560 1595 targets = self._build_targets(targets)[1]
1561 1596
1562 1597 # construct msg_ids from jobs
1563 1598 if jobs == 'all':
1564 1599 msg_ids = jobs
1565 1600 else:
1566 1601 msg_ids = []
1567 1602 if isinstance(jobs, (basestring,AsyncResult)):
1568 1603 jobs = [jobs]
1569 1604 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1570 1605 if bad_ids:
1571 1606 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1572 1607 for j in jobs:
1573 1608 if isinstance(j, AsyncResult):
1574 1609 msg_ids.extend(j.msg_ids)
1575 1610 else:
1576 1611 msg_ids.append(j)
1577 1612
1578 1613 content = dict(engine_ids=targets, msg_ids=msg_ids)
1579 1614 self.session.send(self._query_socket, "purge_request", content=content)
1580 1615 idents, msg = self.session.recv(self._query_socket, 0)
1581 1616 if self.debug:
1582 1617 pprint(msg)
1583 1618 content = msg['content']
1584 1619 if content['status'] != 'ok':
1585 1620 raise self._unwrap_exception(content)
1586 1621
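# Usage sketch for purge_results (hypothetical `rc` and AsyncResult `ar` as above):
#
#   rc.purge_results(jobs=ar)          # forget one result set, by AsyncResult
#   rc.purge_results(targets=[0, 1])   # forget everything engines 0 and 1 ran
#   rc.purge_results('all')            # scrub the Hub's entire db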
1587 1622 @spin_first
1588 1623 def hub_history(self):
1589 1624 """Get the Hub's history
1590 1625
1591 1626 Just like the Client, the Hub has a history, which is a list of msg_ids.
1592 1627 This will contain the history of all clients, and, depending on configuration,
1593 1628 may contain history across multiple cluster sessions.
1594 1629
1595 1630 Any msg_id returned here is a valid argument to `get_result`.
1596 1631
1597 1632 Returns
1598 1633 -------
1599 1634
1600 1635 msg_ids : list of strs
1601 1636 list of all msg_ids, ordered by task submission time.
1602 1637 """
1603 1638
1604 1639 self.session.send(self._query_socket, "history_request", content={})
1605 1640 idents, msg = self.session.recv(self._query_socket, 0)
1606 1641
1607 1642 if self.debug:
1608 1643 pprint(msg)
1609 1644 content = msg['content']
1610 1645 if content['status'] != 'ok':
1611 1646 raise self._unwrap_exception(content)
1612 1647 else:
1613 1648 return content['history']
1614 1649
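# Usage sketch: any msg_id from hub_history is a valid argument to
# get_result (hypothetical `rc` as above):
#
#   msg_ids = rc.hub_history()
#   if msg_ids:
#       ar = rc.get_result(msg_ids[-1])   # fetch the most recently submitted task
#       print ar.get()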
1615 1650 @spin_first
1616 1651 def db_query(self, query, keys=None):
1617 1652 """Query the Hub's TaskRecord database
1618 1653
1619 1654 This will return a list of task record dicts that match `query`
1620 1655
1621 1656 Parameters
1622 1657 ----------
1623 1658
1624 1659 query : mongodb query dict
1625 1660 The search dict. See mongodb query docs for details.
1626 1661 keys : list of strs [optional]
1627 1662 The subset of keys to be returned. The default is to fetch everything but buffers.
1628 1663 'msg_id' will *always* be included.
1629 1664 """
1630 1665 if isinstance(keys, basestring):
1631 1666 keys = [keys]
1632 1667 content = dict(query=query, keys=keys)
1633 1668 self.session.send(self._query_socket, "db_request", content=content)
1634 1669 idents, msg = self.session.recv(self._query_socket, 0)
1635 1670 if self.debug:
1636 1671 pprint(msg)
1637 1672 content = msg['content']
1638 1673 if content['status'] != 'ok':
1639 1674 raise self._unwrap_exception(content)
1640 1675
1641 1676 records = content['records']
1642 1677
1643 1678 buffer_lens = content['buffer_lens']
1644 1679 result_buffer_lens = content['result_buffer_lens']
1645 1680 buffers = msg['buffers']
1646 1681 has_bufs = buffer_lens is not None
1647 1682 has_rbufs = result_buffer_lens is not None
1648 1683 for i,rec in enumerate(records):
1649 1684 # relink buffers
1650 1685 if has_bufs:
1651 1686 blen = buffer_lens[i]
1652 1687 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1653 1688 if has_rbufs:
1654 1689 blen = result_buffer_lens[i]
1655 1690 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1656 1691
1657 1692 return records
1658 1693
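# Usage sketch for db_query, with a mongodb-style filter (hypothetical `rc`;
# 'completed' and 'engine_uuid' are assumed TaskRecord keys):
#
#   recs = rc.db_query({'msg_id': {'$in': rc.hub_history()[-10:]}},
#                      keys=['msg_id', 'completed', 'engine_uuid'])
#   for rec in recs:
#       print rec['msg_id'], rec['completed']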
1659 1694 __all__ = [ 'Client' ]
@@ -1,390 +1,411 b''
1 1 # encoding: utf-8
2 2 """
3 3 =============
4 4 parallelmagic
5 5 =============
6 6
7 7 Magic command interface for interactive parallel work.
8 8
9 9 Usage
10 10 =====
11 11
12 12 ``%autopx``
13 13
14 14 {AUTOPX_DOC}
15 15
16 16 ``%px``
17 17
18 18 {PX_DOC}
19 19
20 ``%result``
20 ``%pxresult``
21 21
22 22 {RESULT_DOC}
23 23
24 ``%pxconfig``
25
26 {CONFIG_DOC}
27
24 28 """
25 29
26 30 #-----------------------------------------------------------------------------
27 31 # Copyright (C) 2008 The IPython Development Team
28 32 #
29 33 # Distributed under the terms of the BSD License. The full license is in
30 34 # the file COPYING, distributed as part of this software.
31 35 #-----------------------------------------------------------------------------
32 36
33 37 #-----------------------------------------------------------------------------
34 38 # Imports
35 39 #-----------------------------------------------------------------------------
36 40
37 41 import ast
38 42 import re
39 43
40 44 from IPython.core.error import UsageError
41 from IPython.core.magic import Magics, magics_class, line_magic, cell_magic
45 from IPython.core.magic import Magics
46 from IPython.core import magic_arguments
42 47 from IPython.testing.skipdoctest import skip_doctest
43 48
44 49 #-----------------------------------------------------------------------------
45 50 # Definitions of magic functions for use with IPython
46 51 #-----------------------------------------------------------------------------
47 52
48 53
49 NO_ACTIVE_VIEW = "Use activate() on a DirectView object to use it with magics."
50 NO_LAST_RESULT = "%result recalls last %px result, which has not yet been used."
54 NO_LAST_RESULT = "%pxresult recalls last %px result, which has not yet been used."
51 55
56 def exec_args(f):
57 """decorator for adding block/targets args for execution
58
59 applied to %pxconfig and %%px
60 """
61 args = [
62 magic_arguments.argument('-b', '--block', action="store_const",
63 const=True, dest='block',
64 help="use blocking (sync) execution"
65 ),
66 magic_arguments.argument('-a', '--noblock', action="store_const",
67 const=False, dest='block',
68 help="use non-blocking (async) execution"
69 ),
70 magic_arguments.argument('-t', '--targets', type=str,
71 help="specify the targets on which to execute"
72 ),
73 ]
74 for a in args:
75 f = a(f)
76 return f
77
78 def output_args(f):
79 """decorator for output-formatting args
80
81 applied to %pxresult and %%px
82 """
83 args = [
84 magic_arguments.argument('-r', action="store_const", dest='groupby',
85 const='order',
86 help="collate outputs in order (same as group-outputs=order)"
87 ),
88 magic_arguments.argument('-e', action="store_const", dest='groupby',
89 const='engine',
90 help="group outputs by engine (same as group-outputs=engine)"
91 ),
92 magic_arguments.argument('--group-outputs', dest='groupby', type=str,
93 choices=['engine', 'order', 'type'], default='type',
94 help="""Group the outputs in a particular way.
95
96 Choices are:
97
98 type: group outputs of all engines by type (stdout, stderr, displaypub, etc.).
99
100 engine: display all output for each engine together.
101
102 order: like type, but individual displaypub output from each engine is collated.
103 For example, if multiple plots are generated by each engine, the first
104 figure of each engine will be displayed, then the second of each, etc.
105 """
106 ),
107 magic_arguments.argument('-o', '--out', dest='save_name', type=str,
108 help="""store the AsyncResult object for this computation
109 in the global namespace under this name.
110 """
111 ),
112 ]
113 for a in args:
114 f = a(f)
115 return f
52 116
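# Sketch of how the stacked argument decorators compose on a magic method
# (`mymagic` is a hypothetical example, mirroring %pxconfig and %%px below):
#
#   @magic_arguments.magic_arguments()
#   @exec_args
#   @output_args
#   def mymagic(self, line):
#       args = magic_arguments.parse_argstring(self.mymagic, line)
#       # exec_args supplies args.block and args.targets;
#       # output_args supplies args.groupby and args.save_name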
53 @magics_class
54 117 class ParallelMagics(Magics):
55 118 """A set of magics useful when controlling a parallel IPython cluster.
56 119 """
57 120
121 # magic-related
122 magics = None
123 registered = True
124
125 # suffix for magics
126 suffix = ''
58 127 # A flag showing if autopx is activated or not
59 128 _autopx = False
60 129 # the current view used by the magics:
61 active_view = None
62 # last result cache for %result
130 view = None
131 # last result cache for %pxresult
63 132 last_result = None
64 133
65 @skip_doctest
66 @line_magic
67 def result(self, line=''):
68 """Print the result of the last asynchronous %px command.
69
70 Usage:
134 def __init__(self, shell, view, suffix=''):
135 self.view = view
136 self.suffix = suffix
71 137
72 %result [-o] [-e] [--group-options=type|engine|order]
138 # register magics
139 self.magics = dict(cell={},line={})
140 line_magics = self.magics['line']
73 141
74 Options:
142 px = 'px' + suffix
143 if not suffix:
144 # keep %result for legacy compatibility
145 line_magics['result'] = self.result
75 146
76 -o: collate outputs in order (same as group-outputs=order)
147 line_magics['pxresult' + suffix] = self.result
148 line_magics[px] = self.px
149 line_magics['pxconfig' + suffix] = self.pxconfig
150 line_magics['auto' + px] = self.autopx
77 151
78 -e: group outputs by engine (same as group-outputs=engine)
152 self.magics['cell'][px] = self.cell_px
79 153
80 --group-outputs=type [default behavior]:
81 each output type (stdout, stderr, displaypub) for all engines
82 displayed together.
83
84 --group-outputs=order:
85 The same as 'type', but individual displaypub outputs (e.g. plots)
86 will be interleaved, so it will display all of the first plots,
87 then all of the second plots, etc.
88
89 --group-outputs=engine:
90 All of an engine's output is displayed before moving on to the next.
91
92 To use this a :class:`DirectView` instance must be created
93 and then activated by calling its :meth:`activate` method.
154 super(ParallelMagics, self).__init__(shell=shell)
155
156 def _eval_target_str(self, ts):
157 if ':' in ts:
158 targets = eval("self.view.client.ids[%s]" % ts)
159 elif 'all' in ts:
160 targets = 'all'
161 else:
162 targets = eval(ts)
163 return targets
164
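# Illustrative (hypothetical) target strings, as mapped by _eval_target_str:
#   '::2'    -> self.view.client.ids[::2]  (slice of the engine-id list)
#   'all'    -> 'all'                      (kept as the special 'all' target)
#   '[0,2]'  -> [0, 2]                     (any other literal is eval'd)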
165 @magic_arguments.magic_arguments()
166 @exec_args
167 def pxconfig(self, line):
168 """configure default targets/blocking for %px magics"""
169 args = magic_arguments.parse_argstring(self.pxconfig, line)
170 if args.targets:
171 self.view.targets = self._eval_target_str(args.targets)
172 if args.block is not None:
173 self.view.block = args.block
174
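# A hedged interactive sketch of %pxconfig (session values illustrative):
#
#   In [1]: %pxconfig --targets ::2   # every other engine by default
#   In [2]: %pxconfig --noblock       # async submission by default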
175 @magic_arguments.magic_arguments()
176 @output_args
177 @skip_doctest
178 def result(self, line=''):
179 """Print the result of the last asynchronous %px command.
94 180
95 181 This lets you recall the results of %px computations after
96 asynchronous submission (view.block=False).
182 asynchronous submission (block=False).
97 183
98 Then you can do the following::
184 Examples
185 --------
186 ::
99 187
100 188 In [23]: %px os.getpid()
101 189 Async parallel execution on engine(s): all
102 190
103 In [24]: %result
191 In [24]: %pxresult
104 192 [ 8] Out[10]: 60920
105 193 [ 9] Out[10]: 60921
106 194 [10] Out[10]: 60922
107 195 [11] Out[10]: 60923
108 196 """
109 opts, _ = self.parse_options(line, 'oe', 'group-outputs=')
110
111 if 'group-outputs' in opts:
112 groupby = opts['group-outputs']
113 elif 'o' in opts:
114 groupby = 'order'
115 elif 'e' in opts:
116 groupby = 'engine'
117 else:
118 groupby = 'type'
119
120 if self.active_view is None:
121 raise UsageError(NO_ACTIVE_VIEW)
197 args = magic_arguments.parse_argstring(self.result, line)
122 198
123 199 if self.last_result is None:
124 200 raise UsageError(NO_LAST_RESULT)
125 201
126 202 self.last_result.get()
127 self.last_result.display_outputs(groupby=groupby)
203 self.last_result.display_outputs(groupby=args.groupby)
128 204
129 205 @skip_doctest
130 @line_magic
131 def px(self, parameter_s=''):
206 def px(self, line=''):
132 207 """Executes the given python command in parallel.
133 208
134 To use this a :class:`DirectView` instance must be created
135 and then activated by calling its :meth:`activate` method.
136
137 Then you can do the following::
209 Examples
210 --------
211 ::
138 212
139 213 In [24]: %px a = os.getpid()
140 214 Parallel execution on engine(s): all
141 215
142 216 In [25]: %px print a
143 217 [stdout:0] 1234
144 218 [stdout:1] 1235
145 219 [stdout:2] 1236
146 220 [stdout:3] 1237
147 221 """
148 return self.parallel_execute(parameter_s)
222 return self.parallel_execute(line)
149 223
150 224 def parallel_execute(self, cell, block=None, groupby='type', save_name=None):
151 225 """implementation used by %px and %%parallel"""
152 226
153 if self.active_view is None:
154 raise UsageError(NO_ACTIVE_VIEW)
155
156 227 # defaults:
157 block = self.active_view.block if block is None else block
228 block = self.view.block if block is None else block
158 229
159 230 base = "Parallel" if block else "Async parallel"
160 231
161 targets = self.active_view.targets
232 targets = self.view.targets
162 233 if isinstance(targets, list) and len(targets) > 10:
163 234 str_targets = str(targets[:4])[:-1] + ', ..., ' + str(targets[-4:])[1:]
164 235 else:
165 236 str_targets = str(targets)
166 237 print base + " execution on engine(s): %s" % str_targets
167 238
168 result = self.active_view.execute(cell, silent=False, block=False)
239 result = self.view.execute(cell, silent=False, block=False)
169 240 self.last_result = result
170 241
171 242 if save_name:
172 243 self.shell.user_ns[save_name] = result
173 244
174 245 if block:
175 246 result.get()
176 247 result.display_outputs(groupby)
177 248 else:
178 249 # return AsyncResult only on non-blocking submission
179 250 return result
180 251
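# A hedged sketch of how these pieces fit together (session illustrative):
#
#   In [1]: %%px --noblock --out ar   # async submission; AsyncResult saved as `ar`
#      ...: import os
#      ...: os.getpid()
#   In [2]: %pxresult -e              # block on `ar`, grouping output by engine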
252 @magic_arguments.magic_arguments()
253 @exec_args
254 @output_args
181 255 @skip_doctest
182 @cell_magic('px')
183 256 def cell_px(self, line='', cell=None):
184 """Executes the given python command in parallel.
185
186 Cell magic usage:
187
188 %%px [-o] [-e] [--group-options=type|engine|order] [--[no]block] [--out name]
189
190 Options:
191
192 --out <name>: store the AsyncResult object for this computation
193 in the global namespace.
194
195 -o: collate outputs in order (same as group-outputs=order)
196
197 -e: group outputs by engine (same as group-outputs=engine)
198
199 --group-outputs=type [default behavior]:
200 each output type (stdout, stderr, displaypub) for all engines
201 displayed together.
202
203 --group-outputs=order:
204 The same as 'type', but individual displaypub outputs (e.g. plots)
205 will be interleaved, so it will display all of the first plots,
206 then all of the second plots, etc.
207
208 --group-outputs=engine:
209 All of an engine's output is displayed before moving on to the next.
210
211 --[no]block:
212 Whether or not to block for the execution to complete
213 (and display the results). If unspecified, the active view's
257 """Executes the cell in parallel.
214 258
215
216 To use this a :class:`DirectView` instance must be created
217 and then activated by calling its :meth:`activate` method.
218
219 Then you can do the following::
259 Examples
260 --------
261 ::
220 262
221 263 In [24]: %%px --noblock
222 264 ....: a = os.getpid()
223 265 Async parallel execution on engine(s): all
224 266
225 267 In [25]: %%px
226 268 ....: print a
227 269 [stdout:0] 1234
228 270 [stdout:1] 1235
229 271 [stdout:2] 1236
230 272 [stdout:3] 1237
231 273 """
232 274
233 block = None
234 groupby = 'type'
235 # as a cell magic, we accept args
236 opts, _ = self.parse_options(line, 'oe', 'group-outputs=', 'out=', 'block', 'noblock')
237
238 if 'group-outputs' in opts:
239 groupby = opts['group-outputs']
240 elif 'o' in opts:
241 groupby = 'order'
242 elif 'e' in opts:
243 groupby = 'engine'
244
245 if 'block' in opts:
246 block = True
247 elif 'noblock' in opts:
248 block = False
249
250 save_name = opts.get('out')
251
252 return self.parallel_execute(cell, block=block, groupby=groupby, save_name=save_name)
253
275 args = magic_arguments.parse_argstring(self.cell_px, line)
276
277 if args.targets:
278 save_targets = self.view.targets
279 self.view.targets = self._eval_target_str(args.targets)
280 try:
281 return self.parallel_execute(cell, block=args.block,
282 groupby=args.groupby,
283 save_name=args.save_name,
284 )
285 finally:
286 if args.targets:
287 self.view.targets = save_targets
288
254 289 @skip_doctest
255 @line_magic
256 def autopx(self, parameter_s=''):
290 def autopx(self, line=''):
257 291 """Toggles auto parallel mode.
258 292
259 To use this a :class:`DirectView` instance must be created
260 and then activated by calling its :meth:`activate` method. Once this
261 is called, all commands typed at the command line are send to
262 the engines to be executed in parallel. To control which engine
263 are used, set the ``targets`` attributed of the multiengine client
264 before entering ``%autopx`` mode.
293 Once this is called, all commands typed at the command line are sent to
294 the engines to be executed in parallel. To control which engines are
295 used, set the ``targets`` attribute of the view before
296 entering ``%autopx`` mode.
297
265 298
266 299 Then you can do the following::
267 300
268 301 In [25]: %autopx
269 302 %autopx enabled
270 303
271 304 In [26]: a = 10
272 305 Parallel execution on engine(s): [0,1,2,3]
273 306 In [27]: print a
274 307 Parallel execution on engine(s): [0,1,2,3]
275 308 [stdout:0] 10
276 309 [stdout:1] 10
277 310 [stdout:2] 10
278 311 [stdout:3] 10
279 312
280 313
281 314 In [28]: %autopx
282 315 %autopx disabled
283 316 """
284 317 if self._autopx:
285 318 self._disable_autopx()
286 319 else:
287 320 self._enable_autopx()
288 321
289 322 def _enable_autopx(self):
290 323 """Enable %autopx mode by saving the original run_cell and installing
291 324 pxrun_cell.
292 325 """
293 if self.active_view is None:
294 raise UsageError(NO_ACTIVE_VIEW)
295
296 326 # override run_cell
297 327 self._original_run_cell = self.shell.run_cell
298 328 self.shell.run_cell = self.pxrun_cell
299 329
300 330 self._autopx = True
301 331 print "%autopx enabled"
302 332
303 333 def _disable_autopx(self):
304 334 """Disable %autopx by restoring the original InteractiveShell.run_cell.
305 335 """
306 336 if self._autopx:
307 337 self.shell.run_cell = self._original_run_cell
308 338 self._autopx = False
309 339 print "%autopx disabled"
310 340
311 341 def pxrun_cell(self, raw_cell, store_history=False, silent=False):
312 342 """drop-in replacement for InteractiveShell.run_cell.
313 343
314 344 This executes code remotely, instead of in the local namespace.
315 345
316 346 See InteractiveShell.run_cell for details.
317 347 """
318 348
319 349 if (not raw_cell) or raw_cell.isspace():
320 350 return
321 351
322 352 ipself = self.shell
323 353
324 354 with ipself.builtin_trap:
325 355 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
326 356
327 357 # Store raw and processed history
328 358 if store_history:
329 359 ipself.history_manager.store_inputs(ipself.execution_count,
330 360 cell, raw_cell)
331 361
332 362 # ipself.logger.log(cell, raw_cell)
333 363
334 364 cell_name = ipself.compile.cache(cell, ipself.execution_count)
335 365
336 366 try:
337 367 ast.parse(cell, filename=cell_name)
338 368 except (OverflowError, SyntaxError, ValueError, TypeError,
339 369 MemoryError):
340 370 # Case 1
341 371 ipself.showsyntaxerror()
342 372 ipself.execution_count += 1
343 373 return None
344 374 except NameError:
345 375 # ignore name errors, because we don't know the remote keys
346 376 pass
347 377
348 378 if store_history:
349 379 # Write output to the database. Does nothing unless
350 380 # history output logging is enabled.
351 381 ipself.history_manager.store_output(ipself.execution_count)
352 382 # Each cell is a *single* input, regardless of how many lines it has
353 383 ipself.execution_count += 1
354 384 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
355 385 self._disable_autopx()
356 386 return False
357 387 else:
358 388 try:
359 result = self.active_view.execute(cell, silent=False, block=False)
389 result = self.view.execute(cell, silent=False, block=False)
360 390 except:
361 391 ipself.showtraceback()
362 392 return True
363 393 else:
364 if self.active_view.block:
394 if self.view.block:
365 395 try:
366 396 result.get()
367 397 except:
368 398 self.shell.showtraceback()
369 399 return True
370 400 else:
371 401 with ipself.builtin_trap:
372 402 result.display_outputs()
373 403 return False
374 404
375 405
376 406 __doc__ = __doc__.format(
377 407 AUTOPX_DOC = ' '*8 + ParallelMagics.autopx.__doc__,
378 408 PX_DOC = ' '*8 + ParallelMagics.px.__doc__,
379 RESULT_DOC = ' '*8 + ParallelMagics.result.__doc__
409 RESULT_DOC = ' '*8 + ParallelMagics.result.__doc__,
410 CONFIG_DOC = ' '*8 + ParallelMagics.pxconfig.__doc__,
380 411 )
381
382 _loaded = False
383
384
385 def load_ipython_extension(ip):
386 """Load the extension in IPython."""
387 global _loaded
388 if not _loaded:
389 ip.register_magics(ParallelMagics)
390 _loaded = True
@@ -1,1100 +1,1104 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89 control methods spin, and requesting `ids` also ensures up to date
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130 130
131 131 def __repr__(self):
132 132 strtargets = str(self.targets)
133 133 if len(strtargets) > 16:
134 134 strtargets = strtargets[:12]+'...]'
135 135 return "<%s %s>"%(self.__class__.__name__, strtargets)
136 136
137 137 def __len__(self):
138 138 if isinstance(self.targets, list):
139 139 return len(self.targets)
140 140 elif isinstance(self.targets, int):
141 141 return 1
142 142 else:
143 143 return len(self.client)
144 144
145 145 def set_flags(self, **kwargs):
146 146 """set my attribute flags by keyword.
147 147
148 148 Views determine behavior with a few attributes (`block`, `track`, etc.).
149 149 These attributes can be set all at once by name with this method.
150 150
151 151 Parameters
152 152 ----------
153 153
154 154 block : bool
155 155 whether to wait for results
156 156 track : bool
157 157 whether to create a MessageTracker to allow the user to
158 158 safely edit arrays and buffers after non-copying
159 159 sends.
160 160 """
161 161 for name, value in kwargs.iteritems():
162 162 if name not in self._flag_names:
163 163 raise KeyError("Invalid name: %r"%name)
164 164 else:
165 165 setattr(self, name, value)
166 166
167 167 @contextmanager
168 168 def temp_flags(self, **kwargs):
169 169 """temporarily set flags, for use in `with` statements.
170 170
171 171 See set_flags for permanent setting of flags
172 172
173 173 Examples
174 174 --------
175 175
176 176 >>> view.track=False
177 177 ...
178 178 >>> with view.temp_flags(track=True):
179 179 ... ar = view.apply(dostuff, my_big_array)
180 180 ... ar.tracker.wait() # wait for send to finish
181 181 >>> view.track
182 182 False
183 183
184 184 """
185 185 # preflight: save flags, and set temporaries
186 186 saved_flags = {}
187 187 for f in self._flag_names:
188 188 saved_flags[f] = getattr(self, f)
189 189 self.set_flags(**kwargs)
190 190 # yield to the with-statement block
191 191 try:
192 192 yield
193 193 finally:
194 194 # postflight: restore saved flags
195 195 self.set_flags(**saved_flags)
196 196
197 197
198 198 #----------------------------------------------------------------
199 199 # apply
200 200 #----------------------------------------------------------------
201 201
202 202 @sync_results
203 203 @save_ids
204 204 def _really_apply(self, f, args, kwargs, block=None, **options):
205 205 """wrapper for client.send_apply_request"""
206 206 raise NotImplementedError("Implement in subclasses")
207 207
208 208 def apply(self, f, *args, **kwargs):
209 209 """calls f(*args, **kwargs) on remote engines, returning the result.
210 210
211 211 This method sets all apply flags via this View's attributes.
212 212
213 213 if self.block is False:
214 214 returns AsyncResult
215 215 else:
216 216 returns actual result of f(*args, **kwargs)
217 217 """
218 218 return self._really_apply(f, args, kwargs)
219 219
220 220 def apply_async(self, f, *args, **kwargs):
221 221 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
222 222
223 223 returns AsyncResult
224 224 """
225 225 return self._really_apply(f, args, kwargs, block=False)
226 226
227 227 @spin_after
228 228 def apply_sync(self, f, *args, **kwargs):
229 229 """calls f(*args, **kwargs) on remote engines in a blocking manner,
230 230 returning the result.
231 231
232 232 returns: actual result of f(*args, **kwargs)
233 233 """
234 234 return self._really_apply(f, args, kwargs, block=True)
235 235
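# Usage sketch for the apply family (hypothetical DirectView `dview = rc[:]`):
#
#   import os
#   ar = dview.apply_async(os.getpid)    # always an AsyncResult
#   pids = ar.get()                      # block for the results
#   pids = dview.apply_sync(os.getpid)   # the same, as one blocking call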
236 236 #----------------------------------------------------------------
237 237 # wrappers for client and control methods
238 238 #----------------------------------------------------------------
239 239 @sync_results
240 240 def spin(self):
241 241 """spin the client, and sync"""
242 242 self.client.spin()
243 243
244 244 @sync_results
245 245 def wait(self, jobs=None, timeout=-1):
246 246 """waits on one or more `jobs`, for up to `timeout` seconds.
247 247
248 248 Parameters
249 249 ----------
250 250
251 251 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
252 252 ints are indices to self.history
253 253 strs are msg_ids
254 254 default: wait on all outstanding messages
255 255 timeout : float
256 256 a time in seconds, after which to give up.
257 257 default is -1, which means no timeout
258 258
259 259 Returns
260 260 -------
261 261
262 262 True : when all msg_ids are done
263 263 False : timeout reached, some msg_ids still outstanding
264 264 """
265 265 if jobs is None:
266 266 jobs = self.history
267 267 return self.client.wait(jobs, timeout)
268 268
269 269 def abort(self, jobs=None, targets=None, block=None):
270 270 """Abort jobs on my engines.
271 271
272 272 Parameters
273 273 ----------
274 274
275 275 jobs : None, str, list of strs, optional
276 276 if None: abort all jobs.
277 277 else: abort specific msg_id(s).
278 278 """
279 279 block = block if block is not None else self.block
280 280 targets = targets if targets is not None else self.targets
281 281 jobs = jobs if jobs is not None else list(self.outstanding)
282 282
283 283 return self.client.abort(jobs=jobs, targets=targets, block=block)
284 284
285 285 def queue_status(self, targets=None, verbose=False):
286 286 """Fetch the Queue status of my engines"""
287 287 targets = targets if targets is not None else self.targets
288 288 return self.client.queue_status(targets=targets, verbose=verbose)
289 289
290 290 def purge_results(self, jobs=[], targets=[]):
291 291 """Instruct the controller to forget specific results."""
292 292 if targets is None or targets == 'all':
293 293 targets = self.targets
294 294 return self.client.purge_results(jobs=jobs, targets=targets)
295 295
296 296 def shutdown(self, targets=None, restart=False, hub=False, block=None):
297 297 """Terminates one or more engine processes, optionally including the hub.
298 298 """
299 299 block = self.block if block is None else block
300 300 if targets is None or targets == 'all':
301 301 targets = self.targets
302 302 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
303 303
304 304 @spin_after
305 305 def get_result(self, indices_or_msg_ids=None):
306 306 """return one or more results, specified by history index or msg_id.
307 307
308 308 See client.get_result for details.
309 309
310 310 """
311 311
312 312 if indices_or_msg_ids is None:
313 313 indices_or_msg_ids = -1
314 314 if isinstance(indices_or_msg_ids, int):
315 315 indices_or_msg_ids = self.history[indices_or_msg_ids]
316 316 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
317 317 indices_or_msg_ids = list(indices_or_msg_ids)
318 318 for i,index in enumerate(indices_or_msg_ids):
319 319 if isinstance(index, int):
320 320 indices_or_msg_ids[i] = self.history[index]
321 321 return self.client.get_result(indices_or_msg_ids)
322 322
323 323 #-------------------------------------------------------------------
324 324 # Map
325 325 #-------------------------------------------------------------------
326 326
327 327 def map(self, f, *sequences, **kwargs):
328 328 """override in subclasses"""
329 329 raise NotImplementedError
330 330
331 331 def map_async(self, f, *sequences, **kwargs):
332 332 """Parallel version of builtin `map`, using this view's engines.
333 333
334 334 This is equivalent to map(...block=False)
335 335
336 336 See `self.map` for details.
337 337 """
338 338 if 'block' in kwargs:
339 339 raise TypeError("map_async doesn't take a `block` keyword argument.")
340 340 kwargs['block'] = False
341 341 return self.map(f,*sequences,**kwargs)
342 342
343 343 def map_sync(self, f, *sequences, **kwargs):
344 344 """Parallel version of builtin `map`, using this view's engines.
345 345
346 346 This is equivalent to map(...block=True)
347 347
348 348 See `self.map` for details.
349 349 """
350 350 if 'block' in kwargs:
351 351 raise TypeError("map_sync doesn't take a `block` keyword argument.")
352 352 kwargs['block'] = True
353 353 return self.map(f,*sequences,**kwargs)
354 354
355 355 def imap(self, f, *sequences, **kwargs):
356 356 """Parallel version of `itertools.imap`.
357 357
358 358 See `self.map` for details.
359 359
360 360 """
361 361
362 362 return iter(self.map_async(f,*sequences, **kwargs))
363 363
364 364 #-------------------------------------------------------------------
365 365 # Decorators
366 366 #-------------------------------------------------------------------
367 367
368 368 def remote(self, block=True, **flags):
369 369 """Decorator for making a RemoteFunction"""
370 370 block = self.block if block is None else block
371 371 return remote(self, block=block, **flags)
372 372
373 373 def parallel(self, dist='b', block=None, **flags):
374 374 """Decorator for making a ParallelFunction"""
375 375 block = self.block if block is None else block
376 376 return parallel(self, dist=dist, block=block, **flags)
377 377
378 378 @skip_doctest
379 379 class DirectView(View):
380 380 """Direct Multiplexer View of one or more engines.
381 381
382 382 These are created via indexed access to a client:
383 383
384 384 >>> dv_1 = client[1]
385 385 >>> dv_all = client[:]
386 386 >>> dv_even = client[::2]
387 387 >>> dv_some = client[1:3]
388 388
389 389 This object provides dictionary access to engine namespaces:
390 390
391 391 # push a=5:
392 392 >>> dv['a'] = 5
393 393 # pull 'foo':
394 394 >>> dv['foo']
395 395
396 396 """
397 397
398 398 def __init__(self, client=None, socket=None, targets=None):
399 399 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
400 400
401 401 @property
402 402 def importer(self):
403 403 """sync_imports(local=True) as a property.
404 404
405 405 See sync_imports for details.
406 406
407 407 """
408 408 return self.sync_imports(True)
409 409
410 410 @contextmanager
411 411 def sync_imports(self, local=True, quiet=False):
412 412 """Context Manager for performing simultaneous local and remote imports.
413 413
414 414 'import x as y' will *not* work. The 'as y' part will simply be ignored.
415 415
416 416 If `local=True`, then the package will also be imported locally.
417 417
418 418 If `quiet=True`, no output will be produced when attempting remote
419 419 imports.
420 420
421 421 Note that remote-only (`local=False`) imports have not been implemented.
422 422
423 423 >>> with view.sync_imports():
424 424 ... from numpy import recarray
425 425 importing recarray from numpy on engine(s)
426 426
427 427 """
428 428 import __builtin__
429 429 local_import = __builtin__.__import__
430 430 modules = set()
431 431 results = []
432 432 @util.interactive
433 433 def remote_import(name, fromlist, level):
434 434 """the function to be passed to apply, that actually performs the import
435 435 on the engine, and loads up the user namespace.
436 436 """
437 437 import sys
438 438 user_ns = globals()
439 439 mod = __import__(name, fromlist=fromlist, level=level)
440 440 if fromlist:
441 441 for key in fromlist:
442 442 user_ns[key] = getattr(mod, key)
443 443 else:
444 444 user_ns[name] = sys.modules[name]
445 445
446 446 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
447 447 """the drop-in replacement for __import__, that optionally imports
448 448 locally as well.
449 449 """
450 450 # don't override nested imports
451 451 save_import = __builtin__.__import__
452 452 __builtin__.__import__ = local_import
453 453
454 454 if imp.lock_held():
455 455 # this is a side-effect import, don't do it remotely, or even
456 456 # ignore the local effects
457 457 return local_import(name, globals, locals, fromlist, level)
458 458
459 459 imp.acquire_lock()
460 460 if local:
461 461 mod = local_import(name, globals, locals, fromlist, level)
462 462 else:
463 463 raise NotImplementedError("remote-only imports not yet implemented")
464 464 imp.release_lock()
465 465
466 466 key = name+':'+','.join(fromlist or [])
467 467 if level == -1 and key not in modules:
468 468 modules.add(key)
469 469 if not quiet:
470 470 if fromlist:
471 471 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
472 472 else:
473 473 print "importing %s on engine(s)"%name
474 474 results.append(self.apply_async(remote_import, name, fromlist, level))
475 475 # restore override
476 476 __builtin__.__import__ = save_import
477 477
478 478 return mod
479 479
480 480 # override __import__
481 481 __builtin__.__import__ = view_import
482 482 try:
483 483 # enter the block
484 484 yield
485 485 except ImportError:
486 486 if local:
487 487 raise
488 488 else:
489 489 # ignore import errors if not doing local imports
490 490 pass
491 491 finally:
492 492 # always restore __import__
493 493 __builtin__.__import__ = local_import
494 494
495 495 for r in results:
496 496 # raise possible remote ImportErrors here
497 497 r.get()
498 498
499 499
500 500 @sync_results
501 501 @save_ids
502 502 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
503 503 """calls f(*args, **kwargs) on remote engines, returning the result.
504 504
505 505 This method sets all of `apply`'s flags via this View's attributes.
506 506
507 507 Parameters
508 508 ----------
509 509
510 510 f : callable
511 511
512 512 args : list [default: empty]
513 513
514 514 kwargs : dict [default: empty]
515 515
516 516 targets : target list [default: self.targets]
517 517 where to run
518 518 block : bool [default: self.block]
519 519 whether to block
520 520 track : bool [default: self.track]
521 521 whether to ask zmq to track the message, for safe non-copying sends
522 522
523 523 Returns
524 524 -------
525 525
526 526 if self.block is False:
527 527 returns AsyncResult
528 528 else:
529 529 returns actual result of f(*args, **kwargs) on the engine(s)
530 530 This will be a list if self.targets is also a list (even length 1), or
531 531 the single result if self.targets is an integer engine id
532 532 """
533 533 args = [] if args is None else args
534 534 kwargs = {} if kwargs is None else kwargs
535 535 block = self.block if block is None else block
536 536 track = self.track if track is None else track
537 537 targets = self.targets if targets is None else targets
538 538
539 539 _idents = self.client._build_targets(targets)[0]
540 540 msg_ids = []
541 541 trackers = []
542 542 for ident in _idents:
543 543 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
544 544 ident=ident)
545 545 if track:
546 546 trackers.append(msg['tracker'])
547 547 msg_ids.append(msg['header']['msg_id'])
548 548 tracker = None if track is False else zmq.MessageTracker(*trackers)
549 549 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
550 550 if block:
551 551 try:
552 552 return ar.get()
553 553 except KeyboardInterrupt:
554 554 pass
555 555 return ar
556 556
557 557
558 558 @spin_after
559 559 def map(self, f, *sequences, **kwargs):
560 560 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
561 561
562 562 Parallel version of builtin `map`, using this View's `targets`.
563 563
564 564 There will be one task per target, so work will be chunked
565 565 if the sequences are longer than `targets`.
566 566
567 567 Results can be iterated as they are ready, but will become available in chunks.
568 568
569 569 Parameters
570 570 ----------
571 571
572 572 f : callable
573 573 function to be mapped
574 574 *sequences: one or more sequences of matching length
575 575 the sequences to be distributed and passed to `f`
576 576 block : bool
577 577 whether to wait for the result or not [default self.block]
578 578
579 579 Returns
580 580 -------
581 581
582 582 if block=False:
583 583 AsyncMapResult
584 584 An object like AsyncResult, but which reassembles the sequence of results
585 585 into a single list. AsyncMapResults can be iterated through before all
586 586 results are complete.
587 587 else:
588 588 list
589 589 the result of map(f,*sequences)
590 590 """
591 591
592 592 block = kwargs.pop('block', self.block)
593 593 for k in kwargs.keys():
594 594 if k not in ['block', 'track']:
595 595 raise TypeError("invalid keyword arg, %r"%k)
596 596
597 597 assert len(sequences) > 0, "must have some sequences to map onto!"
598 598 pf = ParallelFunction(self, f, block=block, **kwargs)
599 599 return pf.map(*sequences)
600 600
601 601 @sync_results
602 602 @save_ids
603 603 def execute(self, code, silent=True, targets=None, block=None):
604 604 """Executes `code` on `targets` in blocking or nonblocking manner.
605 605
606 606 ``execute`` is always `bound` (affects engine namespace)
607 607
608 608 Parameters
609 609 ----------
610 610
611 611 code : str
612 612 the code string to be executed
613 613 block : bool
614 614 whether or not to wait until done to return
615 615 default: self.block
616 616 """
617 617 block = self.block if block is None else block
618 618 targets = self.targets if targets is None else targets
619 619
620 620 _idents = self.client._build_targets(targets)[0]
621 621 msg_ids = []
622 622 trackers = []
623 623 for ident in _idents:
624 624 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
625 625 msg_ids.append(msg['header']['msg_id'])
626 626 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=targets)
627 627 if block:
628 628 try:
629 629 ar.get()
630 630 except KeyboardInterrupt:
631 631 pass
632 632 return ar
633 633
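# Usage sketch for execute (hypothetical `dview` as above):
#
#   ar = dview.execute('y = x ** 2', block=True)  # bound: mutates engine namespaces
#   ar = dview.execute('import os')               # returns an AsyncResult either way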
634 634 def run(self, filename, targets=None, block=None):
635 635 """Execute contents of `filename` on my engine(s).
636 636
637 637 This simply reads the contents of the file and calls `execute`.
638 638
639 639 Parameters
640 640 ----------
641 641
642 642 filename : str
643 643 The path to the file
644 644 targets : int/str/list of ints/strs
645 645 the engines on which to execute
646 646 default : all
647 647 block : bool
648 648 whether or not to wait until done
649 649 default: self.block
650 650
651 651 """
652 652 with open(filename, 'r') as f:
653 653 # add newline in case of trailing indented whitespace
654 654 # which will cause SyntaxError
655 655 code = f.read()+'\n'
656 656 return self.execute(code, block=block, targets=targets)
657 657
658 658 def update(self, ns):
659 659 """update remote namespace with dict `ns`
660 660
661 661 See `push` for details.
662 662 """
663 663 return self.push(ns, block=self.block, track=self.track)
664 664
665 665 def push(self, ns, targets=None, block=None, track=None):
666 666 """update remote namespace with dict `ns`
667 667
668 668 Parameters
669 669 ----------
670 670
671 671 ns : dict
672 672 dict of keys with which to update engine namespace(s)
673 673 block : bool [default : self.block]
674 674 whether to wait to be notified of engine receipt
675 675
676 676 """
677 677
678 678 block = block if block is not None else self.block
679 679 track = track if track is not None else self.track
680 680 targets = targets if targets is not None else self.targets
681 681 # applier = self.apply_sync if block else self.apply_async
682 682 if not isinstance(ns, dict):
683 683 raise TypeError("Must be a dict, not %s"%type(ns))
684 684 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
685 685
686 686 def get(self, key_s):
687 687 """get object(s) by `key_s` from remote namespace
688 688
689 689 see `pull` for details.
690 690 """
691 691 # block = block if block is not None else self.block
692 692 return self.pull(key_s, block=True)
693 693
694 694 def pull(self, names, targets=None, block=None):
695 695 """get object(s) by `name` from remote namespace
696 696
697 697 will return one object if it is a key.
698 698 can also take a list of keys, in which case it will return a list of objects.
699 699 """
700 700 block = block if block is not None else self.block
701 701 targets = targets if targets is not None else self.targets
702 702 applier = self.apply_sync if block else self.apply_async
703 703 if isinstance(names, basestring):
704 704 pass
705 705 elif isinstance(names, (list,tuple,set)):
706 706 for key in names:
707 707 if not isinstance(key, basestring):
708 708 raise TypeError("keys must be str, not type %r"%type(key))
709 709 else:
710 710 raise TypeError("names must be strs, not %r"%names)
711 711 return self._really_apply(util._pull, (names,), block=block, targets=targets)
712 712
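# Usage sketch for namespace access (hypothetical `dview` as above):
#
#   dview.push(dict(a=5), block=True)   # update engine namespaces
#   dview.pull('a', block=True)         # -> a list of 5s, one per engine
#   dview['b'] = 10                     # dict-style push via __setitem__
#   print dview['b']                    # dict-style blocking pull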
713 713 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
714 714 """
715 715 Partition a Python sequence and send the partitions to a set of engines.
716 716 """
717 717 block = block if block is not None else self.block
718 718 track = track if track is not None else self.track
719 719 targets = targets if targets is not None else self.targets
720 720
721 721 # construct integer ID list:
722 722 targets = self.client._build_targets(targets)[1]
723 723
724 724 mapObject = Map.dists[dist]()
725 725 nparts = len(targets)
726 726 msg_ids = []
727 727 trackers = []
728 728 for index, engineid in enumerate(targets):
729 729 partition = mapObject.getPartition(seq, index, nparts)
730 730 if flatten and len(partition) == 1:
731 731 ns = {key: partition[0]}
732 732 else:
733 733 ns = {key: partition}
734 734 r = self.push(ns, block=False, track=track, targets=engineid)
735 735 msg_ids.extend(r.msg_ids)
736 736 if track:
737 737 trackers.append(r._tracker)
738 738
739 739 if track:
740 740 tracker = zmq.MessageTracker(*trackers)
741 741 else:
742 742 tracker = None
743 743
744 744 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
745 745 if block:
746 746 r.wait()
747 747 else:
748 748 return r
749 749
750 750 @sync_results
751 751 @save_ids
752 752 def gather(self, key, dist='b', targets=None, block=None):
753 753 """
754 754 Gather a partitioned sequence on a set of engines as a single local seq.
755 755 """
756 756 block = block if block is not None else self.block
757 757 targets = targets if targets is not None else self.targets
758 758 mapObject = Map.dists[dist]()
759 759 msg_ids = []
760 760
761 761 # construct integer ID list:
762 762 targets = self.client._build_targets(targets)[1]
763 763
764 764 for index, engineid in enumerate(targets):
765 765 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
766 766
767 767 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
768 768
769 769 if block:
770 770 try:
771 771 return r.get()
772 772 except KeyboardInterrupt:
773 773 pass
774 774 return r
775 775
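# Usage sketch for scatter/gather (hypothetical `dview` as above):
#
#   dview.scatter('x', range(8), block=True)   # partition range(8) across engines
#   dview.execute('y = [i**2 for i in x]', block=True)
#   squares = dview.gather('y', block=True)    # reassembled into a single list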
776 776 def __getitem__(self, key):
777 777 return self.get(key)
778 778
779 779 def __setitem__(self,key, value):
780 780 self.update({key:value})
781 781
782 782 def clear(self, targets=None, block=None):
783 783 """Clear the remote namespaces on my engines."""
784 784 block = block if block is not None else self.block
785 785 targets = targets if targets is not None else self.targets
786 786 return self.client.clear(targets=targets, block=block)
787 787
788 788 def kill(self, targets=None, block=True):
789 789 """Kill my engines."""
790 790 block = block if block is not None else self.block
791 791 targets = targets if targets is not None else self.targets
792 792 return self.client.kill(targets=targets, block=block)
793 793
794 794 #----------------------------------------
795 # activate for %px,%autopx magics
795 # activate for %px, %autopx, etc. magics
796 796 #----------------------------------------
797 def activate(self):
798 """Make this `View` active for parallel magic commands.
799 797
800 IPython has a magic command syntax to work with `MultiEngineClient` objects.
801 In a given IPython session there is a single active one. While
802 there can be many `Views` created and used by the user,
803 there is only one active one. The active `View` is used whenever
804 the magic commands %px and %autopx are used.
805
806 The activate() method is called on a given `View` to make it
807 active. Once this has been done, the magic commands can be used.
798 def activate(self, suffix=''):
799 """Activate IPython magics associated with this View
800
801 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
802
803 Parameters
804 ----------
805
806 suffix: str [default: '']
807 The suffix, if any, for the magics. This allows you to have
808 multiple views associated with parallel magics at the same time.
809
810 e.g. ``rc[::2].activate(suffix='_even')`` will give you
811 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
812 on the even engines.
808 813 """
809
814
815 from IPython.parallel.client.magics import ParallelMagics
816
810 817 try:
811 818 # This is injected into __builtins__.
812 819 ip = get_ipython()
813 820 except NameError:
814 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
815 else:
816 pmagic = ip.magics_manager.registry.get('ParallelMagics')
817 if pmagic is None:
818 ip.magic('load_ext parallelmagic')
819 pmagic = ip.magics_manager.registry.get('ParallelMagics')
820
821 pmagic.active_view = self
821 print "The IPython parallel magics (%px, etc.) only work within IPython."
822 return
823
824 M = ParallelMagics(ip, self, suffix)
825 ip.magics_manager.register(M)
822 826
823 827
824 828 @skip_doctest
825 829 class LoadBalancedView(View):
826 830 """An load-balancing View that only executes via the Task scheduler.
827 831
828 832 Load-balanced views can be created with the client's `view` method:
829 833
830 834 >>> v = client.load_balanced_view()
831 835
832 836 or targets can be specified, to restrict the potential destinations:
833 837
834 838 >>> v = client.load_balanced_view([1,3])
835 839
836 840 which would restrict loadbalancing to between engines 1 and 3.
837 841
838 842 """
839 843
840 844 follow=Any()
841 845 after=Any()
842 846 timeout=CFloat()
843 847 retries = Integer(0)
844 848
845 849 _task_scheme = Any()
846 850 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
847 851
848 852 def __init__(self, client=None, socket=None, **flags):
849 853 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
850 854 self._task_scheme=client._task_scheme
851 855
852 856 def _validate_dependency(self, dep):
853 857 """validate a dependency.
854 858
855 859 For use in `set_flags`.
856 860 """
857 861 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
858 862 return True
859 863 elif isinstance(dep, (list,set, tuple)):
860 864 for d in dep:
861 865 if not isinstance(d, (basestring, AsyncResult)):
862 866 return False
863 867 elif isinstance(dep, dict):
864 868 if set(dep.keys()) != set(Dependency().as_dict().keys()):
865 869 return False
866 870 if not isinstance(dep['msg_ids'], list):
867 871 return False
868 872 for d in dep['msg_ids']:
869 873 if not isinstance(d, basestring):
870 874 return False
871 875 else:
872 876 return False
873 877
874 878 return True
875 879
876 880 def _render_dependency(self, dep):
877 881 """helper for building jsonable dependencies from various input forms."""
878 882 if isinstance(dep, Dependency):
879 883 return dep.as_dict()
880 884 elif isinstance(dep, AsyncResult):
881 885 return dep.msg_ids
882 886 elif dep is None:
883 887 return []
884 888 else:
885 889 # pass to Dependency constructor
886 890 return list(Dependency(dep))
887 891
888 892 def set_flags(self, **kwargs):
889 893 """set my attribute flags by keyword.
890 894
891 895 A View is a wrapper for the Client's apply method, but with attributes
892 896 that specify keyword arguments, those attributes can be set by keyword
893 897 argument with this method.
894 898
895 899 Parameters
896 900 ----------
897 901
898 902 block : bool
899 903 whether to wait for results
900 904 track : bool
901 905 whether to create a MessageTracker to allow the user to
902 906 safely edit arrays and buffers after non-copying
903 907 sends.
904 908
905 909 after : Dependency or collection of msg_ids
906 910 Only for load-balanced execution (targets=None)
907 911 Specify a list of msg_ids as a time-based dependency.
908 912 This job will only be run *after* the dependencies
909 913 have been met.
910 914
911 915 follow : Dependency or collection of msg_ids
912 916 Only for load-balanced execution (targets=None)
913 917 Specify a list of msg_ids as a location-based dependency.
914 918 This job will only be run on an engine where this dependency
915 919 is met.
916 920
917 921 timeout : float/int or None
918 922 Only for load-balanced execution (targets=None)
919 923 Specify an amount of time (in seconds) for the scheduler to
920 924 wait for dependencies to be met before failing with a
921 925 DependencyTimeout.
922 926
923 927 retries : int
924 928 Number of times a task will be retried on failure.
925 929 """
926 930
927 931 super(LoadBalancedView, self).set_flags(**kwargs)
928 932 for name in ('follow', 'after'):
929 933 if name in kwargs:
930 934 value = kwargs[name]
931 935 if self._validate_dependency(value):
932 936 setattr(self, name, value)
933 937 else:
934 938 raise ValueError("Invalid dependency: %r"%value)
935 939 if 'timeout' in kwargs:
936 940 t = kwargs['timeout']
937 941 if not isinstance(t, (int, long, float, type(None))):
938 942 raise TypeError("Invalid type for timeout: %r"%type(t))
939 943 if t is not None:
940 944 if t < 0:
941 945 raise ValueError("Invalid timeout: %s"%t)
942 946 self.timeout = t
943 947
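# Usage sketch for scheduler flags (hypothetical `lview = rc.load_balanced_view()`
# and illustrative functions setup_func/work_func):
#
#   ar = lview.apply_async(setup_func)
#   lview.set_flags(after=ar, retries=2)   # later tasks run only after `ar`
#   with lview.temp_flags(timeout=10):
#       ar2 = lview.apply_async(work_func) # DependencyTimeout if unmet in 10s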
944 948 @sync_results
945 949 @save_ids
946 950 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
947 951 after=None, follow=None, timeout=None,
948 952 targets=None, retries=None):
949 953 """calls f(*args, **kwargs) on a remote engine, returning the result.
950 954
951 955 This method temporarily sets all of `apply`'s flags for a single call.
952 956
953 957 Parameters
954 958 ----------
955 959
956 960 f : callable
957 961
958 962 args : list [default: empty]
959 963
960 964 kwargs : dict [default: empty]
961 965
962 966 block : bool [default: self.block]
963 967 whether to block
964 968 track : bool [default: self.track]
965 969 whether to ask zmq to track the message, for safe non-copying sends
966 970
967 971 after, follow, timeout, targets, retries : as documented for `set_flags`
968 972
969 973 Returns
970 974 -------
971 975
972 976 if self.block is False:
973 977 returns AsyncResult
974 978 else:
975 979 returns actual result of f(*args, **kwargs) on the engine(s)
976 980 This will be a list if self.targets is also a list (even length 1), or
977 981 the single result if self.targets is an integer engine id
978 982 """
979 983
980 984 # validate whether we can run
981 985 if self._socket.closed:
982 986 msg = "Task farming is disabled"
983 987 if self._task_scheme == 'pure':
984 988 msg += " because the pure ZMQ scheduler cannot handle"
985 989 msg += " disappearing engines."
986 990 raise RuntimeError(msg)
987 991
988 992 if self._task_scheme == 'pure':
989 993 # pure zmq scheme doesn't support extra features
990 994 msg = "Pure ZMQ scheduler doesn't support the following flags:"
991 995 "follow, after, retries, targets, timeout"
992 996 if (follow or after or retries or targets or timeout):
993 997 # hard fail on Scheduler flags
994 998 raise RuntimeError(msg)
995 999 if isinstance(f, dependent):
996 1000 # soft warn on functional dependencies
997 1001 warnings.warn(msg, RuntimeWarning)
998 1002
999 1003 # build args
1000 1004 args = [] if args is None else args
1001 1005 kwargs = {} if kwargs is None else kwargs
1002 1006 block = self.block if block is None else block
1003 1007 track = self.track if track is None else track
1004 1008 after = self.after if after is None else after
1005 1009 retries = self.retries if retries is None else retries
1006 1010 follow = self.follow if follow is None else follow
1007 1011 timeout = self.timeout if timeout is None else timeout
1008 1012 targets = self.targets if targets is None else targets
1009 1013
1010 1014 if not isinstance(retries, int):
1011 1015 raise TypeError('retries must be int, not %r'%type(retries))
1012 1016
1013 1017 if targets is None:
1014 1018 idents = []
1015 1019 else:
1016 1020 idents = self.client._build_targets(targets)[0]
1017 1021 # ensure *not* bytes
1018 1022 idents = [ ident.decode() for ident in idents ]
1019 1023
1020 1024 after = self._render_dependency(after)
1021 1025 follow = self._render_dependency(follow)
1022 1026 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1023 1027
1024 1028 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1025 1029 subheader=subheader)
1026 1030 tracker = None if track is False else msg['tracker']
1027 1031
1028 1032 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1029 1033
1030 1034 if block:
1031 1035 try:
1032 1036 return ar.get()
1033 1037 except KeyboardInterrupt:
1034 1038 pass
1035 1039 return ar
1036 1040
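A short sketch of the two return modes described above, via the public apply wrappers (a hypothetical session, assuming a running cluster):

    ar = v.apply_async(pow, 2, 10)     # never blocks; returns an AsyncResult
    ar.get()                           # -> 1024 once an engine finishes

    result = v.apply_sync(pow, 2, 10)  # blocks; returns 1024 directly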
1037 1041 @spin_after
1038 1042 @save_ids
1039 1043 def map(self, f, *sequences, **kwargs):
1040 1044 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1041 1045
1042 1046 Parallel version of builtin `map`, load-balanced by this View.
1043 1047
1044 1048 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1045 1049
1046 1050 Each group of `chunksize` elements will be submitted as a separate,
1047 1051 load-balanced task. This lets individual elements become available
1048 1052 for iteration as soon as they arrive.
1049 1053
1050 1054 Parameters
1051 1055 ----------
1052 1056
1053 1057 f : callable
1054 1058 function to be mapped
1055 1059 *sequences: one or more sequences of matching length
1056 1060 the sequences to be distributed and passed to `f`
1057 1061 block : bool [default self.block]
1058 1062 whether to wait for the result or not
1059 1063 track : bool
1060 1064 whether to create a MessageTracker to allow the user to
1061 1065 safely edit after arrays and buffers during non-copying
1062 1066 sends.
1063 1067 chunksize : int [default 1]
1064 1068 how many elements should be in each task.
1065 1069 ordered : bool [default True]
1066 1070 Whether the results should be yielded in the order of submission
1067 1071 (True), or as they arrive (False).
1068 1072
1069 1073 Only applies when iterating through AsyncMapResult as results arrive.
1070 1074 Has no effect when block=True.
1071 1075
1072 1076 Returns
1073 1077 -------
1074 1078
1075 1079 if block=False:
1076 1080 AsyncMapResult
1077 1081 An object like AsyncResult, but which reassembles the sequence of results
1078 1082 into a single list. AsyncMapResults can be iterated through before all
1079 1083 results are complete.
1080 1084 else:
1081 1085 the result of map(f,*sequences)
1082 1086
1083 1087 """
1084 1088
1085 1089 # default
1086 1090 block = kwargs.get('block', self.block)
1087 1091 chunksize = kwargs.get('chunksize', 1)
1088 1092 ordered = kwargs.get('ordered', True)
1089 1093
1090 1094 # reject unrecognized keyword arguments
1091 1095 extra_keys = set(kwargs.keys()).difference(['block', 'chunksize', 'ordered'])
1092 1096 if extra_keys:
1093 1097 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1094 1098
1095 1099 assert len(sequences) > 0, "must have some sequences to map onto!"
1096 1100
1097 1101 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1098 1102 return pf.map(*sequences)
1099 1103
1100 1104 __all__ = ['LoadBalancedView', 'DirectView']
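To illustrate the chunking and ordering options of `map` above, a hedged sketch (assuming a running cluster):

    from IPython.parallel import Client

    rc = Client()
    v = rc.load_balanced_view()

    amr = v.map(lambda x: x * x, range(32), block=False,
                chunksize=4, ordered=False)
    for r in amr:   # with ordered=False, results yield as tasks finish
        print(r)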
@@ -1,422 +1,436 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 from IPython import parallel
27 28 from IPython.parallel.client import client as clientmod
28 29 from IPython.parallel import error
29 30 from IPython.parallel import AsyncResult, AsyncHubResult
30 31 from IPython.parallel import LoadBalancedView, DirectView
31 32
32 33 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 34
34 35 def setup():
35 36 add_engines(4, total=True)
36 37
37 38 class TestClient(ClusterTestCase):
38 39
39 40 def test_ids(self):
40 41 n = len(self.client.ids)
41 42 self.add_engines(2)
42 43 self.assertEquals(len(self.client.ids), n+2)
43 44
44 45 def test_view_indexing(self):
45 46 """test index access for views"""
46 47 self.minimum_engines(4)
47 48 targets = self.client._build_targets('all')[-1]
48 49 v = self.client[:]
49 50 self.assertEquals(v.targets, targets)
50 51 t = self.client.ids[2]
51 52 v = self.client[t]
52 53 self.assert_(isinstance(v, DirectView))
53 54 self.assertEquals(v.targets, t)
54 55 t = self.client.ids[2:4]
55 56 v = self.client[t]
56 57 self.assert_(isinstance(v, DirectView))
57 58 self.assertEquals(v.targets, t)
58 59 v = self.client[::2]
59 60 self.assert_(isinstance(v, DirectView))
60 61 self.assertEquals(v.targets, targets[::2])
61 62 v = self.client[1::3]
62 63 self.assert_(isinstance(v, DirectView))
63 64 self.assertEquals(v.targets, targets[1::3])
64 65 v = self.client[:-3]
65 66 self.assert_(isinstance(v, DirectView))
66 67 self.assertEquals(v.targets, targets[:-3])
67 68 v = self.client[-1]
68 69 self.assert_(isinstance(v, DirectView))
69 70 self.assertEquals(v.targets, targets[-1])
70 71 self.assertRaises(TypeError, lambda : self.client[None])
71 72
72 73 def test_lbview_targets(self):
73 74 """test load_balanced_view targets"""
74 75 v = self.client.load_balanced_view()
75 76 self.assertEquals(v.targets, None)
76 77 v = self.client.load_balanced_view(-1)
77 78 self.assertEquals(v.targets, [self.client.ids[-1]])
78 79 v = self.client.load_balanced_view('all')
79 80 self.assertEquals(v.targets, None)
80 81
81 82 def test_dview_targets(self):
82 83 """test direct_view targets"""
83 84 v = self.client.direct_view()
84 85 self.assertEquals(v.targets, 'all')
85 86 v = self.client.direct_view('all')
86 87 self.assertEquals(v.targets, 'all')
87 88 v = self.client.direct_view(-1)
88 89 self.assertEquals(v.targets, self.client.ids[-1])
89 90
90 91 def test_lazy_all_targets(self):
91 92 """test lazy evaluation of rc.direct_view('all')"""
92 93 v = self.client.direct_view()
93 94 self.assertEquals(v.targets, 'all')
94 95
95 96 def double(x):
96 97 return x*2
97 98 seq = range(100)
98 99 ref = [ double(x) for x in seq ]
99 100
100 101 # add some engines, which should be used
101 102 self.add_engines(1)
102 103 n1 = len(self.client.ids)
103 104
104 105 # simple apply
105 106 r = v.apply_sync(lambda : 1)
106 107 self.assertEquals(r, [1] * n1)
107 108
108 109 # map goes through remotefunction
109 110 r = v.map_sync(double, seq)
110 111 self.assertEquals(r, ref)
111 112
112 113 # add a couple more engines, and try again
113 114 self.add_engines(2)
114 115 n2 = len(self.client.ids)
115 116 self.assertNotEquals(n2, n1)
116 117
117 118 # apply
118 119 r = v.apply_sync(lambda : 1)
119 120 self.assertEquals(r, [1] * n2)
120 121
121 122 # map
122 123 r = v.map_sync(double, seq)
123 124 self.assertEquals(r, ref)
124 125
125 126 def test_targets(self):
126 127 """test various valid targets arguments"""
127 128 build = self.client._build_targets
128 129 ids = self.client.ids
129 130 idents,targets = build(None)
130 131 self.assertEquals(ids, targets)
131 132
132 133 def test_clear(self):
133 134 """test clear behavior"""
134 135 self.minimum_engines(2)
135 136 v = self.client[:]
136 137 v.block=True
137 138 v.push(dict(a=5))
138 139 v.pull('a')
139 140 id0 = self.client.ids[-1]
140 141 self.client.clear(targets=id0, block=True)
141 142 a = self.client[:-1].get('a')
142 143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 144 self.client.clear(block=True)
144 145 for i in self.client.ids:
145 146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 147
147 148 def test_get_result(self):
148 149 """test getting results from the Hub."""
149 150 c = clientmod.Client(profile='iptest')
150 151 t = c.ids[-1]
151 152 ar = c[t].apply_async(wait, 1)
152 153 # give the monitor time to notice the message
153 154 time.sleep(.25)
154 155 ahr = self.client.get_result(ar.msg_ids)
155 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 157 self.assertEquals(ahr.get(), ar.get())
157 158 ar2 = self.client.get_result(ar.msg_ids)
158 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 160 c.close()
160 161
161 162 def test_ids_list(self):
162 163 """test client.ids"""
163 164 ids = self.client.ids
164 165 self.assertEquals(ids, self.client._ids)
165 166 self.assertFalse(ids is self.client._ids)
166 167 ids.remove(ids[-1])
167 168 self.assertNotEquals(ids, self.client._ids)
168 169
169 170 def test_queue_status(self):
170 171 ids = self.client.ids
171 172 id0 = ids[0]
172 173 qs = self.client.queue_status(targets=id0)
173 174 self.assertTrue(isinstance(qs, dict))
174 175 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
175 176 allqs = self.client.queue_status()
176 177 self.assertTrue(isinstance(allqs, dict))
177 178 intkeys = list(allqs.keys())
178 179 intkeys.remove('unassigned')
179 180 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
180 181 unassigned = allqs.pop('unassigned')
181 182 for eid,qs in allqs.items():
182 183 self.assertTrue(isinstance(qs, dict))
183 184 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
184 185
185 186 def test_shutdown(self):
186 187 ids = self.client.ids
187 188 id0 = ids[0]
188 189 self.client.shutdown(id0, block=True)
189 190 while id0 in self.client.ids:
190 191 time.sleep(0.1)
191 192 self.client.spin()
192 193
193 194 self.assertRaises(IndexError, lambda : self.client[id0])
194 195
195 196 def test_result_status(self):
196 197 pass
197 198 # to be written
198 199
199 200 def test_db_query_dt(self):
200 201 """test db query by date"""
201 202 hist = self.client.hub_history()
202 203 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
203 204 tic = middle['submitted']
204 205 before = self.client.db_query({'submitted' : {'$lt' : tic}})
205 206 after = self.client.db_query({'submitted' : {'$gte' : tic}})
206 207 self.assertEquals(len(before)+len(after),len(hist))
207 208 for b in before:
208 209 self.assertTrue(b['submitted'] < tic)
209 210 for a in after:
210 211 self.assertTrue(a['submitted'] >= tic)
211 212 same = self.client.db_query({'submitted' : tic})
212 213 for s in same:
213 214 self.assertTrue(s['submitted'] == tic)
214 215
215 216 def test_db_query_keys(self):
216 217 """test extracting subset of record keys"""
217 218 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
218 219 for rec in found:
219 220 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
220 221
221 222 def test_db_query_default_keys(self):
222 223 """default db_query excludes buffers"""
223 224 found = self.client.db_query({'msg_id': {'$ne' : ''}})
224 225 for rec in found:
225 226 keys = set(rec.keys())
226 227 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
227 228 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
228 229
229 230 def test_db_query_msg_id(self):
230 231 """ensure msg_id is always in db queries"""
231 232 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
232 233 for rec in found:
233 234 self.assertTrue('msg_id' in rec.keys())
234 235 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
235 236 for rec in found:
236 237 self.assertTrue('msg_id' in rec.keys())
237 238 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
238 239 for rec in found:
239 240 self.assertTrue('msg_id' in rec.keys())
240 241
241 242 def test_db_query_get_result(self):
242 243 """pop in db_query shouldn't pop from result itself"""
243 244 self.client[:].apply_sync(lambda : 1)
244 245 found = self.client.db_query({'msg_id': {'$ne' : ''}})
245 246 rc2 = clientmod.Client(profile='iptest')
246 247 # If this bug is not fixed, this call will hang:
247 248 ar = rc2.get_result(self.client.history[-1])
248 249 ar.wait(2)
249 250 self.assertTrue(ar.ready())
250 251 ar.get()
251 252 rc2.close()
252 253
253 254 def test_db_query_in(self):
254 255 """test db query with '$in','$nin' operators"""
255 256 hist = self.client.hub_history()
256 257 even = hist[::2]
257 258 odd = hist[1::2]
258 259 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
259 260 found = [ r['msg_id'] for r in recs ]
260 261 self.assertEquals(set(even), set(found))
261 262 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
262 263 found = [ r['msg_id'] for r in recs ]
263 264 self.assertEquals(set(odd), set(found))
264 265
265 266 def test_hub_history(self):
266 267 hist = self.client.hub_history()
267 268 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
268 269 recdict = {}
269 270 for rec in recs:
270 271 recdict[rec['msg_id']] = rec
271 272
272 273 latest = datetime(1984,1,1)
273 274 for msg_id in hist:
274 275 rec = recdict[msg_id]
275 276 newt = rec['submitted']
276 277 self.assertTrue(newt >= latest)
277 278 latest = newt
278 279 ar = self.client[-1].apply_async(lambda : 1)
279 280 ar.get()
280 281 time.sleep(0.25)
281 282 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
282 283
283 284 def _wait_for_idle(self):
284 285 """wait for an engine to become idle, according to the Hub"""
285 286 rc = self.client
286 287
287 288 # timeout 2s, polling every 100ms
288 289 for i in range(20):
289 290 qs = rc.queue_status()
290 291 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
291 292 time.sleep(0.1)
292 293 else:
293 294 break
294 295
295 296 # ensure Hub up to date:
296 297 qs = rc.queue_status()
297 298 self.assertEquals(qs['unassigned'], 0)
298 299 for eid in rc.ids:
299 300 self.assertEquals(qs[eid]['tasks'], 0)
300 301
301 302
302 303 def test_resubmit(self):
303 304 def f():
304 305 import random
305 306 return random.random()
306 307 v = self.client.load_balanced_view()
307 308 ar = v.apply_async(f)
308 309 r1 = ar.get(1)
309 310 # give the Hub a chance to notice:
310 311 self._wait_for_idle()
311 312 ahr = self.client.resubmit(ar.msg_ids)
312 313 r2 = ahr.get(1)
313 314 self.assertFalse(r1 == r2)
314 315
315 316 def test_resubmit_chain(self):
316 317 """resubmit resubmitted tasks"""
317 318 v = self.client.load_balanced_view()
318 319 ar = v.apply_async(lambda x: x, 'x'*1024)
319 320 ar.get()
320 321 self._wait_for_idle()
321 322 ars = [ar]
322 323
323 324 for i in range(10):
324 325 ar = ars[-1]
325 326 ar2 = self.client.resubmit(ar.msg_ids)
327 ars.append(ar2)
326 327
327 328 [ ar.get() for ar in ars ]
328 329
329 330 def test_resubmit_header(self):
330 331 """resubmit shouldn't clobber the whole header"""
331 332 def f():
332 333 import random
333 334 return random.random()
334 335 v = self.client.load_balanced_view()
335 336 v.retries = 1
336 337 ar = v.apply_async(f)
337 338 r1 = ar.get(1)
338 339 # give the Hub a chance to notice:
339 340 self._wait_for_idle()
340 341 ahr = self.client.resubmit(ar.msg_ids)
341 342 ahr.get(1)
342 343 time.sleep(0.5)
343 344 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
344 345 h1,h2 = [ r['header'] for r in records ]
345 346 for key in set(h1.keys()).union(set(h2.keys())):
346 347 if key in ('msg_id', 'date'):
347 348 self.assertNotEquals(h1[key], h2[key])
348 349 else:
349 350 self.assertEquals(h1[key], h2[key])
350 351
351 352 def test_resubmit_aborted(self):
352 353 def f():
353 354 import random
354 355 return random.random()
355 356 v = self.client.load_balanced_view()
356 357 # restrict to one engine, so we can put a sleep
357 358 # ahead of the task, so it will get aborted
358 359 eid = self.client.ids[-1]
359 360 v.targets = [eid]
360 361 sleep = v.apply_async(time.sleep, 0.5)
361 362 ar = v.apply_async(f)
362 363 ar.abort()
363 364 self.assertRaises(error.TaskAborted, ar.get)
364 365 # Give the Hub a chance to get up to date:
365 366 self._wait_for_idle()
366 367 ahr = self.client.resubmit(ar.msg_ids)
367 368 r2 = ahr.get(1)
368 369
369 370 def test_resubmit_inflight(self):
370 371 """resubmit of inflight task"""
371 372 v = self.client.load_balanced_view()
372 373 ar = v.apply_async(time.sleep,1)
373 374 # give the message a chance to arrive
374 375 time.sleep(0.2)
375 376 ahr = self.client.resubmit(ar.msg_ids)
376 377 ar.get(2)
377 378 ahr.get(2)
378 379
379 380 def test_resubmit_badkey(self):
380 381 """ensure KeyError on resubmit of nonexistant task"""
381 382 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
382 383
383 384 def test_purge_results(self):
384 385 # ensure there are some tasks
385 386 for i in range(5):
386 387 self.client[:].apply_sync(lambda : 1)
387 388 # Wait for the Hub to realise the result is done:
388 389 # This prevents a race condition, where we
389 390 # might purge a result the Hub still thinks is pending.
390 391 time.sleep(0.1)
391 392 rc2 = clientmod.Client(profile='iptest')
392 393 hist = self.client.hub_history()
393 394 ahr = rc2.get_result([hist[-1]])
394 395 ahr.wait(10)
395 396 self.client.purge_results(hist[-1])
396 397 newhist = self.client.hub_history()
397 398 self.assertEquals(len(newhist)+1,len(hist))
398 399 rc2.spin()
399 400 rc2.close()
400 401
401 402 def test_purge_all_results(self):
402 403 self.client.purge_results('all')
403 404 hist = self.client.hub_history()
404 405 self.assertEquals(len(hist), 0)
405 406
406 407 def test_spin_thread(self):
407 408 self.client.spin_thread(0.01)
408 409 ar = self.client[-1].apply_async(lambda : 1)
409 410 time.sleep(0.1)
410 411 self.assertTrue(ar.wall_time < 0.1,
411 412 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
412 413 )
413 414
414 415 def test_stop_spin_thread(self):
415 416 self.client.spin_thread(0.01)
416 417 self.client.stop_spin_thread()
417 418 ar = self.client[-1].apply_async(lambda : 1)
418 419 time.sleep(0.15)
419 420 self.assertTrue(ar.wall_time > 0.1,
420 421 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
421 422 )
422 423
424 def test_activate(self):
425 ip = get_ipython()
426 magics = ip.magics_manager.magics
427 self.assertTrue('px' in magics['line'])
428 self.assertTrue('px' in magics['cell'])
429 v0 = self.client.activate(-1, '0')
430 self.assertTrue('px0' in magics['line'])
431 self.assertTrue('px0' in magics['cell'])
432 self.assertEquals(v0.targets, self.client.ids[-1])
433 v0 = self.client.activate('all', 'all')
434 self.assertTrue('pxall' in magics['line'])
435 self.assertTrue('pxall' in magics['cell'])
436 self.assertEquals(v0.targets, 'all')
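The new test above exercises `Client.activate`, which registers suffixed variants of the parallel magics. A rough interactive sketch (assuming an IPython session with a running cluster; the `_even` suffix is illustrative):

    from IPython.parallel import Client

    rc = Client()
    even = rc.activate(rc.ids[::2], suffix='_even')
    # %px_even and %%px_even now target the even-numbered engines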
@@ -1,340 +1,345 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Test Parallel magics
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import re
20 20 import sys
21 21 import time
22 22
23 23 import zmq
24 24 from nose import SkipTest
25 25
26 26 from IPython.testing import decorators as dec
27 27 from IPython.testing.ipunittest import ParametricTestCase
28 28 from IPython.utils.io import capture_output
29 29
30 30 from IPython import parallel as pmod
31 31 from IPython.parallel import error
32 32 from IPython.parallel import AsyncResult
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, generate_output
38 38
39 39 def setup():
40 40 add_engines(3, total=True)
41 41
42 42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
43 43
44 44 def test_px_blocking(self):
45 45 ip = get_ipython()
46 46 v = self.client[-1:]
47 47 v.activate()
48 48 v.block=True
49 49
50 50 ip.magic('px a=5')
51 51 self.assertEquals(v['a'], [5])
52 52 ip.magic('px a=10')
53 53 self.assertEquals(v['a'], [10])
54 54 # just 'print a' works ~99% of the time, but this ensures that
55 55 # the stdout message has arrived when the result is finished:
56 56 with capture_output() as io:
57 57 ip.magic(
58 58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
59 59 )
60 60 out = io.stdout
61 61 self.assertTrue('[stdout:' in out, out)
62 62 self.assertFalse('\n\n' in out)
63 63 self.assertTrue(out.rstrip().endswith('10'))
64 64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
65 65
66 66 def _check_generated_stderr(self, stderr, n):
67 67 expected = [
68 68 r'\[stderr:\d+\]',
69 69 '^stderr$',
70 70 '^stderr2$',
71 71 ] * n
72 72
73 73 self.assertFalse('\n\n' in stderr, stderr)
74 74 lines = stderr.splitlines()
75 75 self.assertEquals(len(lines), len(expected), stderr)
76 76 for line,expect in zip(lines, expected):
77 77 if isinstance(expect, str):
78 78 expect = [expect]
79 79 for ex in expect:
80 80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
81 81
82 82 def test_cellpx_block_args(self):
83 83 """%%px --[no]block flags work"""
84 84 ip = get_ipython()
85 85 v = self.client[-1:]
86 86 v.activate()
87 87 v.block=False
88 88
89 89 for block in (True, False):
90 90 v.block = block
91 91
92 92 with capture_output() as io:
93 93 ip.run_cell_magic("px", "", "1")
94 94 if block:
95 95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
96 96 else:
97 97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
98 98
99 99 with capture_output() as io:
100 100 ip.run_cell_magic("px", "--block", "1")
101 101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
102 102
103 103 with capture_output() as io:
104 104 ip.run_cell_magic("px", "--noblock", "1")
105 105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
106 106
107 107 def test_cellpx_groupby_engine(self):
108 108 """%%px --group-outputs=engine"""
109 109 ip = get_ipython()
110 110 v = self.client[:]
111 111 v.block = True
112 112 v.activate()
113 113
114 114 v['generate_output'] = generate_output
115 115
116 116 with capture_output() as io:
117 117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
118 118
119 119 self.assertFalse('\n\n' in io.stdout)
120 120 lines = io.stdout.splitlines()[1:]
121 121 expected = [
122 122 r'\[stdout:\d+\]',
123 123 'stdout',
124 124 'stdout2',
125 125 r'\[output:\d+\]',
126 126 r'IPython\.core\.display\.HTML',
127 127 r'IPython\.core\.display\.Math',
128 128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
129 129 ] * len(v)
130 130
131 131 self.assertEquals(len(lines), len(expected), io.stdout)
132 132 for line,expect in zip(lines, expected):
133 133 if isinstance(expect, str):
134 134 expect = [expect]
135 135 for ex in expect:
136 136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
137 137
138 138 self._check_generated_stderr(io.stderr, len(v))
139 139
140 140
141 141 def test_cellpx_groupby_order(self):
142 142 """%%px --group-outputs=order"""
143 143 ip = get_ipython()
144 144 v = self.client[:]
145 145 v.block = True
146 146 v.activate()
147 147
148 148 v['generate_output'] = generate_output
149 149
150 150 with capture_output() as io:
151 151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
152 152
153 153 self.assertFalse('\n\n' in io.stdout)
154 154 lines = io.stdout.splitlines()[1:]
155 155 expected = []
156 156 expected.extend([
157 157 r'\[stdout:\d+\]',
158 158 'stdout',
159 159 'stdout2',
160 160 ] * len(v))
161 161 expected.extend([
162 162 r'\[output:\d+\]',
163 163 'IPython.core.display.HTML',
164 164 ] * len(v))
165 165 expected.extend([
166 166 r'\[output:\d+\]',
167 167 'IPython.core.display.Math',
168 168 ] * len(v))
169 169 expected.extend([
170 170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
171 171 ] * len(v))
172 172
173 173 self.assertEquals(len(lines), len(expected), io.stdout)
174 174 for line,expect in zip(lines, expected):
175 175 if isinstance(expect, str):
176 176 expect = [expect]
177 177 for ex in expect:
178 178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
179 179
180 180 self._check_generated_stderr(io.stderr, len(v))
181 181
182 182 def test_cellpx_groupby_type(self):
183 183 """%%px --group-outputs=type"""
184 184 ip = get_ipython()
185 185 v = self.client[:]
186 186 v.block = True
187 187 v.activate()
188 188
189 189 v['generate_output'] = generate_output
190 190
191 191 with capture_output() as io:
192 192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
193 193
194 194 self.assertFalse('\n\n' in io.stdout)
195 195 lines = io.stdout.splitlines()[1:]
196 196
197 197 expected = []
198 198 expected.extend([
199 199 r'\[stdout:\d+\]',
200 200 'stdout',
201 201 'stdout2',
202 202 ] * len(v))
203 203 expected.extend([
204 204 r'\[output:\d+\]',
205 205 r'IPython\.core\.display\.HTML',
206 206 r'IPython\.core\.display\.Math',
207 207 ] * len(v))
208 208 expected.extend([
209 209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
210 210 ] * len(v))
211 211
212 212 self.assertEquals(len(lines), len(expected), io.stdout)
213 213 for line,expect in zip(lines, expected):
214 214 if isinstance(expect, str):
215 215 expect = [expect]
216 216 for ex in expect:
217 217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
218 218
219 219 self._check_generated_stderr(io.stderr, len(v))
220 220
221 221
222 222 def test_px_nonblocking(self):
223 223 ip = get_ipython()
224 224 v = self.client[-1:]
225 225 v.activate()
226 226 v.block=False
227 227
228 228 ip.magic('px a=5')
229 229 self.assertEquals(v['a'], [5])
230 230 ip.magic('px a=10')
231 231 self.assertEquals(v['a'], [10])
232 232 with capture_output() as io:
233 233 ar = ip.magic('px print (a)')
234 234 self.assertTrue(isinstance(ar, AsyncResult))
235 235 self.assertTrue('Async' in io.stdout)
236 236 self.assertFalse('[stdout:' in io.stdout)
237 237 self.assertFalse('\n\n' in io.stdout)
238 238
239 239 ar = ip.magic('px 1/0')
240 240 self.assertRaisesRemote(ZeroDivisionError, ar.get)
241 241
242 242 def test_autopx_blocking(self):
243 243 ip = get_ipython()
244 244 v = self.client[-1]
245 245 v.activate()
246 246 v.block=True
247 247
248 248 with capture_output() as io:
249 249 ip.magic('autopx')
250 250 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
251 251 ip.run_cell('b*=2')
252 252 ip.run_cell('print (b)')
253 253 ip.run_cell('b')
254 254 ip.run_cell("b/c")
255 255 ip.magic('autopx')
256 256
257 257 output = io.stdout
258 258
259 259 self.assertTrue(output.startswith('%autopx enabled'), output)
260 260 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
261 261 self.assertTrue('ZeroDivisionError' in output, output)
262 262 self.assertTrue('\nOut[' in output, output)
263 263 self.assertTrue(': 24690' in output, output)
264 264 ar = v.get_result(-1)
265 265 self.assertEquals(v['a'], 5)
266 266 self.assertEquals(v['b'], 24690)
267 267 self.assertRaisesRemote(ZeroDivisionError, ar.get)
268 268
269 269 def test_autopx_nonblocking(self):
270 270 ip = get_ipython()
271 271 v = self.client[-1]
272 272 v.activate()
273 273 v.block=False
274 274
275 275 with capture_output() as io:
276 276 ip.magic('autopx')
277 277 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
278 278 ip.run_cell('print (b)')
279 279 ip.run_cell('import time; time.sleep(0.1)')
280 280 ip.run_cell("b/c")
281 281 ip.run_cell('b*=2')
282 282 ip.magic('autopx')
283 283
284 284 output = io.stdout.rstrip()
285 285
286 286 self.assertTrue(output.startswith('%autopx enabled'))
287 287 self.assertTrue(output.endswith('%autopx disabled'))
288 288 self.assertFalse('ZeroDivisionError' in output)
289 289 ar = v.get_result(-2)
290 290 self.assertRaisesRemote(ZeroDivisionError, ar.get)
291 291 # prevent TaskAborted on pulls, due to ZeroDivisionError
292 292 time.sleep(0.5)
293 293 self.assertEquals(v['a'], 5)
294 294 # b*=2 will not fire, due to abort
295 295 self.assertEquals(v['b'], 10)
296 296
297 297 def test_result(self):
298 298 ip = get_ipython()
299 299 v = self.client[-1]
300 300 v.activate()
301 301 data = dict(a=111,b=222)
302 302 v.push(data, block=True)
303 303
304 ip.magic('px a')
305 ip.magic('px b')
306 for idx, name in [
307 ('', 'b'),
308 ('-1', 'b'),
309 ('2', 'b'),
310 ('1', 'a'),
311 ('-2', 'a'),
312 ]:
304 for name in ('a', 'b'):
305 ip.magic('px ' + name)
313 306 with capture_output() as io:
314 ip.magic('result ' + idx)
307 ip.magic('pxresult')
315 308 output = io.stdout
316 309 msg = "expected %s output to include %s, but got: %s" % \
317 ('%result '+idx, str(data[name]), output)
310 ('%pxresult', str(data[name]), output)
318 311 self.assertTrue(str(data[name]) in output, msg)
319 312
320 313 @dec.skipif_not_matplotlib
321 314 def test_px_pylab(self):
322 315 """%pylab works on engines"""
323 316 ip = get_ipython()
324 317 v = self.client[-1]
325 318 v.block = True
326 319 v.activate()
327 320
328 321 with capture_output() as io:
329 322 ip.magic("px %pylab inline")
330 323
331 324 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
332 325 self.assertTrue("backend_inline" in io.stdout, io.stdout)
333 326
334 327 with capture_output() as io:
335 328 ip.magic("px plot(rand(100))")
336 329
337 330 self.assertTrue('Out[' in io.stdout, io.stdout)
338 331 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
332
333 def test_pxconfig(self):
334 ip = get_ipython()
335 rc = self.client
336 v = rc.activate(-1, '_tst')
337 self.assertEquals(v.targets, rc.ids[-1])
338 ip.magic("%pxconfig_tst -t :")
339 self.assertEquals(v.targets, rc.ids)
340 ip.magic("%pxconfig_tst --block")
341 self.assertEquals(v.block, True)
342 ip.magic("%pxconfig_tst --noblock")
343 self.assertEquals(v.block, False)
339 344
340 345