view.abort() aborts all outstanding tasks...
MinRK
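The change below makes the `jobs` argument to `abort()` optional: when omitted (or None), every outstanding task on the targeted engines is aborted. A minimal usage sketch, assuming a running cluster on the default profile:

    from IPython.parallel import Client

    rc = Client()                      # connect via the default profile
    view = rc.load_balanced_view()
    ars = [view.apply_async(pow, 2, n) for n in range(10)]   # queue some work
    view.abort()                       # no arguments: aborts *all* outstanding tasks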
--- a/IPython/parallel/client/client.py
+++ b/IPython/parallel/client/client.py
@@ -1,1440 +1,1443 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
52 52 if sys.version_info[0] >= 3:
53 53 # xrange is used in a couple 'isinstance' tests in py2
54 54 # should be just 'range' in 3k
55 55 xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
61 61 @decorator
62 62 def spin_first(f, self, *args, **kwargs):
63 63 """Call spin() to sync state prior to calling the method."""
64 64 self.spin()
65 65 return f(self, *args, **kwargs)
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
72 72 class Metadata(dict):
73 73 """Subclass of dict for initializing metadata values.
74 74
75 75 Attribute access works on keys.
76 76
77 77 These objects have a strict set of keys - an error is raised if you try
78 78 to add new keys.
79 79 """
80 80 def __init__(self, *args, **kwargs):
81 81 dict.__init__(self)
82 82 md = {'msg_id' : None,
83 83 'submitted' : None,
84 84 'started' : None,
85 85 'completed' : None,
86 86 'received' : None,
87 87 'engine_uuid' : None,
88 88 'engine_id' : None,
89 89 'follow' : None,
90 90 'after' : None,
91 91 'status' : None,
92 92
93 93 'pyin' : None,
94 94 'pyout' : None,
95 95 'pyerr' : None,
96 96 'stdout' : '',
97 97 'stderr' : '',
98 98 }
99 99 self.update(md)
100 100 self.update(dict(*args, **kwargs))
101 101
102 102 def __getattr__(self, key):
103 103 """getattr aliased to getitem"""
104 104 if key in self.iterkeys():
105 105 return self[key]
106 106 else:
107 107 raise AttributeError(key)
108 108
109 109 def __setattr__(self, key, value):
110 110 """setattr aliased to setitem, with strict"""
111 111 if key in self.iterkeys():
112 112 self[key] = value
113 113 else:
114 114 raise AttributeError(key)
115 115
116 116 def __setitem__(self, key, value):
117 117 """strict static key enforcement"""
118 118 if key in self.iterkeys():
119 119 dict.__setitem__(self, key, value)
120 120 else:
121 121 raise KeyError(key)
122 122
123 123
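A Metadata instance behaves like a dict with a fixed key set; a hypothetical session:

    md = Metadata()
    md.status              # None -- attribute access reads the same keys as md['status']
    md['stdout'] += 'hi'   # existing keys can be updated freely
    md.bogus = 1           # raises AttributeError: unknown keys are rejected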
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 166 # and set sshserver as the same machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to ssh private key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False else]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use the semi-private _ids attribute.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures the state is up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
244 244
245 245 block = Bool(False)
246 246 outstanding = Set()
247 247 results = Instance('collections.defaultdict', (dict,))
248 248 metadata = Instance('collections.defaultdict', (Metadata,))
249 249 history = List()
250 250 debug = Bool(False)
251 251
252 252 profile=Unicode()
253 253 def _profile_default(self):
254 254 if BaseIPythonApplication.initialized():
255 255 # an IPython app *might* be running, try to get its profile
256 256 try:
257 257 return BaseIPythonApplication.instance().profile
258 258 except (AttributeError, MultipleInstanceError):
259 259 # could be a *different* subclass of config.Application,
260 260 # which would raise one of these two errors.
261 261 return u'default'
262 262 else:
263 263 return u'default'
264 264
265 265
266 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 267 _ids = List()
268 268 _connected=Bool(False)
269 269 _ssh=Bool(False)
270 270 _context = Instance('zmq.Context')
271 271 _config = Dict()
272 272 _engines=Instance(util.ReverseDict, (), {})
273 273 # _hub_socket=Instance('zmq.Socket')
274 274 _query_socket=Instance('zmq.Socket')
275 275 _control_socket=Instance('zmq.Socket')
276 276 _iopub_socket=Instance('zmq.Socket')
277 277 _notification_socket=Instance('zmq.Socket')
278 278 _mux_socket=Instance('zmq.Socket')
279 279 _task_socket=Instance('zmq.Socket')
280 280 _task_scheme=Unicode()
281 281 _closed = False
282 282 _ignored_control_replies=Integer(0)
283 283 _ignored_hub_replies=Integer(0)
284 284
285 285 def __new__(self, *args, **kw):
286 286 # don't raise on positional args
287 287 return HasTraits.__new__(self, **kw)
288 288
289 289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 290 context=None, debug=False, exec_key=None,
291 291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 292 timeout=10, **extra_args
293 293 ):
294 294 if profile:
295 295 super(Client, self).__init__(debug=debug, profile=profile)
296 296 else:
297 297 super(Client, self).__init__(debug=debug)
298 298 if context is None:
299 299 context = zmq.Context.instance()
300 300 self._context = context
301 301
302 302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 303 if self._cd is not None:
304 304 if url_or_file is None:
305 305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 307 " Please specify at least one of url_or_file or profile."
308 308
309 309 if not util.is_url(url_or_file):
310 310 # it's not a url, try for a file
311 311 if not os.path.exists(url_or_file):
312 312 if self._cd:
313 313 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
314 314 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
315 315 with open(url_or_file) as f:
316 316 cfg = json.loads(f.read())
317 317 else:
318 318 cfg = {'url':url_or_file}
319 319
320 320 # sync defaults from args, json:
321 321 if sshserver:
322 322 cfg['ssh'] = sshserver
323 323 if exec_key:
324 324 cfg['exec_key'] = exec_key
325 325 exec_key = cfg['exec_key']
326 326 location = cfg.setdefault('location', None)
327 327 cfg['url'] = util.disambiguate_url(cfg['url'], location)
328 328 url = cfg['url']
329 329 proto,addr,port = util.split_url(url)
330 330 if location is not None and addr == '127.0.0.1':
331 331 # location specified, and connection is expected to be local
332 332 if location not in LOCAL_IPS and not sshserver:
333 333 # load ssh from JSON *only* if the controller is not on
334 334 # this machine
335 335 sshserver=cfg['ssh']
336 336 if location not in LOCAL_IPS and not sshserver:
337 337 # warn if no ssh specified, but SSH is probably needed
338 338 # This is only a warning, because the most likely cause
339 339 # is a local Controller on a laptop whose IP is dynamic
340 340 warnings.warn("""
341 341 Controller appears to be listening on localhost, but not on this machine.
342 342 If this is true, you should specify Client(...,sshserver='you@%s')
343 343 or instruct your controller to listen on an external IP."""%location,
344 344 RuntimeWarning)
345 345 elif not sshserver:
346 346 # otherwise sync with cfg
347 347 sshserver = cfg['ssh']
348 348
349 349 self._config = cfg
350 350
351 351 self._ssh = bool(sshserver or sshkey or password)
352 352 if self._ssh and sshserver is None:
353 353 # default to ssh via localhost
354 354 sshserver = url.split('://')[1].split(':')[0]
355 355 if self._ssh and password is None:
356 356 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
357 357 password=False
358 358 else:
359 359 password = getpass("SSH Password for %s: "%sshserver)
360 360 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
361 361
362 362 # configure and construct the session
363 363 if exec_key is not None:
364 364 if os.path.isfile(exec_key):
365 365 extra_args['keyfile'] = exec_key
366 366 else:
367 367 exec_key = util.asbytes(exec_key)
368 368 extra_args['key'] = exec_key
369 369 self.session = Session(**extra_args)
370 370
371 371 self._query_socket = self._context.socket(zmq.DEALER)
372 372 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
373 373 if self._ssh:
374 374 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
375 375 else:
376 376 self._query_socket.connect(url)
377 377
378 378 self.session.debug = self.debug
379 379
380 380 self._notification_handlers = {'registration_notification' : self._register_engine,
381 381 'unregistration_notification' : self._unregister_engine,
382 382 'shutdown_notification' : lambda msg: self.close(),
383 383 }
384 384 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
385 385 'apply_reply' : self._handle_apply_reply}
386 386 self._connect(sshserver, ssh_kwargs, timeout)
387 387
388 388 def __del__(self):
389 389 """cleanup sockets, but _not_ context."""
390 390 self.close()
391 391
392 392 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 393 if ipython_dir is None:
394 394 ipython_dir = get_ipython_dir()
395 395 if profile_dir is not None:
396 396 try:
397 397 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 398 return
399 399 except ProfileDirError:
400 400 pass
401 401 elif profile is not None:
402 402 try:
403 403 self._cd = ProfileDir.find_profile_dir_by_name(
404 404 ipython_dir, profile)
405 405 return
406 406 except ProfileDirError:
407 407 pass
408 408 self._cd = None
409 409
410 410 def _update_engines(self, engines):
411 411 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 412 for k,v in engines.iteritems():
413 413 eid = int(k)
414 414 self._engines[eid] = v
415 415 self._ids.append(eid)
416 416 self._ids = sorted(self._ids)
417 417 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 418 self._task_scheme == 'pure' and self._task_socket:
419 419 self._stop_scheduling_tasks()
420 420
421 421 def _stop_scheduling_tasks(self):
422 422 """Stop scheduling tasks because an engine has been unregistered
423 423 from a pure ZMQ scheduler.
424 424 """
425 425 self._task_socket.close()
426 426 self._task_socket = None
427 427 msg = "An engine has been unregistered, and we are using pure " +\
428 428 "ZMQ task scheduling. Task farming will be disabled."
429 429 if self.outstanding:
430 430 msg += " If you were running tasks when this happened, " +\
431 431 "some `outstanding` msg_ids may never resolve."
432 432 warnings.warn(msg, RuntimeWarning)
433 433
434 434 def _build_targets(self, targets):
435 435 """Turn valid target IDs or 'all' into two lists:
436 436 (int_ids, uuids).
437 437 """
438 438 if not self._ids:
439 439 # flush notification socket if no engines yet, just in case
440 440 if not self.ids:
441 441 raise error.NoEnginesRegistered("Can't build targets without any engines")
442 442
443 443 if targets is None:
444 444 targets = self._ids
445 445 elif isinstance(targets, basestring):
446 446 if targets.lower() == 'all':
447 447 targets = self._ids
448 448 else:
449 449 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 450 elif isinstance(targets, int):
451 451 if targets < 0:
452 452 targets = self.ids[targets]
453 453 if targets not in self._ids:
454 454 raise IndexError("No such engine: %i"%targets)
455 455 targets = [targets]
456 456
457 457 if isinstance(targets, slice):
458 458 indices = range(len(self._ids))[targets]
459 459 ids = self.ids
460 460 targets = [ ids[i] for i in indices ]
461 461
462 462 if not isinstance(targets, (tuple, list, xrange)):
463 463 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464 464
465 465 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466 466
467 467 def _connect(self, sshserver, ssh_kwargs, timeout):
468 468 """setup all our socket connections to the cluster. This is called from
469 469 __init__."""
470 470
471 471 # Maybe allow reconnecting?
472 472 if self._connected:
473 473 return
474 474 self._connected=True
475 475
476 476 def connect_socket(s, url):
477 477 url = util.disambiguate_url(url, self._config['location'])
478 478 if self._ssh:
479 479 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
480 480 else:
481 481 return s.connect(url)
482 482
483 483 self.session.send(self._query_socket, 'connection_request')
484 484 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
485 485 poller = zmq.Poller()
486 486 poller.register(self._query_socket, zmq.POLLIN)
487 487 # poll expects milliseconds, timeout is seconds
488 488 evts = poller.poll(timeout*1000)
489 489 if not evts:
490 490 raise error.TimeoutError("Hub connection request timed out")
491 491 idents,msg = self.session.recv(self._query_socket,mode=0)
492 492 if self.debug:
493 493 pprint(msg)
494 494 msg = Message(msg)
495 495 content = msg.content
496 496 self._config['registration'] = dict(content)
497 497 if content.status == 'ok':
498 498 ident = self.session.bsession
499 499 if content.mux:
500 500 self._mux_socket = self._context.socket(zmq.DEALER)
501 501 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
502 502 connect_socket(self._mux_socket, content.mux)
503 503 if content.task:
504 504 self._task_scheme, task_addr = content.task
505 505 self._task_socket = self._context.socket(zmq.DEALER)
506 506 self._task_socket.setsockopt(zmq.IDENTITY, ident)
507 507 connect_socket(self._task_socket, task_addr)
508 508 if content.notification:
509 509 self._notification_socket = self._context.socket(zmq.SUB)
510 510 connect_socket(self._notification_socket, content.notification)
511 511 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
512 512 # if content.query:
513 513 # self._query_socket = self._context.socket(zmq.DEALER)
514 514 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
515 515 # connect_socket(self._query_socket, content.query)
516 516 if content.control:
517 517 self._control_socket = self._context.socket(zmq.DEALER)
518 518 self._control_socket.setsockopt(zmq.IDENTITY, ident)
519 519 connect_socket(self._control_socket, content.control)
520 520 if content.iopub:
521 521 self._iopub_socket = self._context.socket(zmq.SUB)
522 522 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
523 523 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
524 524 connect_socket(self._iopub_socket, content.iopub)
525 525 self._update_engines(dict(content.engines))
526 526 else:
527 527 self._connected = False
528 528 raise Exception("Failed to connect!")
529 529
530 530 #--------------------------------------------------------------------------
531 531 # handlers and callbacks for incoming messages
532 532 #--------------------------------------------------------------------------
533 533
534 534 def _unwrap_exception(self, content):
535 535 """unwrap exception, and remap engine_id to int."""
536 536 e = error.unwrap_exception(content)
537 537 # print e.traceback
538 538 if e.engine_info:
539 539 e_uuid = e.engine_info['engine_uuid']
540 540 eid = self._engines[e_uuid]
541 541 e.engine_info['engine_id'] = eid
542 542 return e
543 543
544 544 def _extract_metadata(self, header, parent, content):
545 545 md = {'msg_id' : parent['msg_id'],
546 546 'received' : datetime.now(),
547 547 'engine_uuid' : header.get('engine', None),
548 548 'follow' : parent.get('follow', []),
549 549 'after' : parent.get('after', []),
550 550 'status' : content['status'],
551 551 }
552 552
553 553 if md['engine_uuid'] is not None:
554 554 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555 555
556 556 if 'date' in parent:
557 557 md['submitted'] = parent['date']
558 558 if 'started' in header:
559 559 md['started'] = header['started']
560 560 if 'date' in header:
561 561 md['completed'] = header['date']
562 562 return md
563 563
564 564 def _register_engine(self, msg):
565 565 """Register a new engine, and update our connection info."""
566 566 content = msg['content']
567 567 eid = content['id']
568 568 d = {eid : content['queue']}
569 569 self._update_engines(d)
570 570
571 571 def _unregister_engine(self, msg):
572 572 """Unregister an engine that has died."""
573 573 content = msg['content']
574 574 eid = int(content['id'])
575 575 if eid in self._ids:
576 576 self._ids.remove(eid)
577 577 uuid = self._engines.pop(eid)
578 578
579 579 self._handle_stranded_msgs(eid, uuid)
580 580
581 581 if self._task_socket and self._task_scheme == 'pure':
582 582 self._stop_scheduling_tasks()
583 583
584 584 def _handle_stranded_msgs(self, eid, uuid):
585 585 """Handle messages known to be on an engine when the engine unregisters.
586 586
587 587 It is possible that this will fire prematurely - that is, an engine will
588 588 go down after completing a result, and the client will be notified
589 589 of the unregistration and later receive the successful result.
590 590 """
591 591
592 592 outstanding = self._outstanding_dict[uuid]
593 593
594 594 for msg_id in list(outstanding):
595 595 if msg_id in self.results:
596 596 # we already have this result
597 597 continue
598 598 try:
599 599 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 600 except:
601 601 content = error.wrap_exception()
602 602 # build a fake message:
603 603 parent = {}
604 604 header = {}
605 605 parent['msg_id'] = msg_id
606 606 header['engine'] = uuid
607 607 header['date'] = datetime.now()
608 608 msg = dict(parent_header=parent, header=header, content=content)
609 609 self._handle_apply_reply(msg)
610 610
611 611 def _handle_execute_reply(self, msg):
612 612 """Save the reply to an execute_request into our results.
613 613
614 614 execute messages are never actually used. apply is used instead.
615 615 """
616 616
617 617 parent = msg['parent_header']
618 618 msg_id = parent['msg_id']
619 619 if msg_id not in self.outstanding:
620 620 if msg_id in self.history:
621 621 print ("got stale result: %s"%msg_id)
622 622 else:
623 623 print ("got unknown result: %s"%msg_id)
624 624 else:
625 625 self.outstanding.remove(msg_id)
626 626 self.results[msg_id] = self._unwrap_exception(msg['content'])
627 627
628 628 def _handle_apply_reply(self, msg):
629 629 """Save the reply to an apply_request into our results."""
630 630 parent = msg['parent_header']
631 631 msg_id = parent['msg_id']
632 632 if msg_id not in self.outstanding:
633 633 if msg_id in self.history:
634 634 print ("got stale result: %s"%msg_id)
635 635 print self.results[msg_id]
636 636 print msg
637 637 else:
638 638 print ("got unknown result: %s"%msg_id)
639 639 else:
640 640 self.outstanding.remove(msg_id)
641 641 content = msg['content']
642 642 header = msg['header']
643 643
644 644 # construct metadata:
645 645 md = self.metadata[msg_id]
646 646 md.update(self._extract_metadata(header, parent, content))
647 647 # is this redundant?
648 648 self.metadata[msg_id] = md
649 649
650 650 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 651 if msg_id in e_outstanding:
652 652 e_outstanding.remove(msg_id)
653 653
654 654 # construct result:
655 655 if content['status'] == 'ok':
656 656 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 657 elif content['status'] == 'aborted':
658 658 self.results[msg_id] = error.TaskAborted(msg_id)
659 659 elif content['status'] == 'resubmitted':
660 660 # TODO: handle resubmission
661 661 pass
662 662 else:
663 663 self.results[msg_id] = self._unwrap_exception(content)
664 664
665 665 def _flush_notifications(self):
666 666 """Flush notifications of engine registrations waiting
667 667 in ZMQ queue."""
668 668 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 669 while msg is not None:
670 670 if self.debug:
671 671 pprint(msg)
672 672 msg_type = msg['header']['msg_type']
673 673 handler = self._notification_handlers.get(msg_type, None)
674 674 if handler is None:
675 675 raise Exception("Unhandled message type: %s"%msg_type)
676 676 else:
677 677 handler(msg)
678 678 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679 679
680 680 def _flush_results(self, sock):
681 681 """Flush task or queue results waiting in ZMQ queue."""
682 682 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 683 while msg is not None:
684 684 if self.debug:
685 685 pprint(msg)
686 686 msg_type = msg['header']['msg_type']
687 687 handler = self._queue_handlers.get(msg_type, None)
688 688 if handler is None:
689 689 raise Exception("Unhandled message type: %s"%msg_type)
690 690 else:
691 691 handler(msg)
692 692 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 693
694 694 def _flush_control(self, sock):
695 695 """Flush replies from the control channel waiting
696 696 in the ZMQ queue.
697 697
698 698 Currently: ignore them."""
699 699 if self._ignored_control_replies <= 0:
700 700 return
701 701 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 702 while msg is not None:
703 703 self._ignored_control_replies -= 1
704 704 if self.debug:
705 705 pprint(msg)
706 706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 707
708 708 def _flush_ignored_control(self):
709 709 """flush ignored control replies"""
710 710 while self._ignored_control_replies > 0:
711 711 self.session.recv(self._control_socket)
712 712 self._ignored_control_replies -= 1
713 713
714 714 def _flush_ignored_hub_replies(self):
715 715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 716 while msg is not None:
717 717 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718 718
719 719 def _flush_iopub(self, sock):
720 720 """Flush replies from the iopub channel waiting
721 721 in the ZMQ queue.
722 722 """
723 723 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 724 while msg is not None:
725 725 if self.debug:
726 726 pprint(msg)
727 727 parent = msg['parent_header']
728 728 msg_id = parent['msg_id']
729 729 content = msg['content']
730 730 header = msg['header']
731 731 msg_type = msg['header']['msg_type']
732 732
733 733 # init metadata:
734 734 md = self.metadata[msg_id]
735 735
736 736 if msg_type == 'stream':
737 737 name = content['name']
738 738 s = md[name] or ''
739 739 md[name] = s + content['data']
740 740 elif msg_type == 'pyerr':
741 741 md.update({'pyerr' : self._unwrap_exception(content)})
742 742 elif msg_type == 'pyin':
743 743 md.update({'pyin' : content['code']})
744 744 else:
745 745 md.update({msg_type : content.get('data', '')})
746 746
747 747 # redundant?
748 748 self.metadata[msg_id] = md
749 749
750 750 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751 751
752 752 #--------------------------------------------------------------------------
753 753 # len, getitem
754 754 #--------------------------------------------------------------------------
755 755
756 756 def __len__(self):
757 757 """len(client) returns # of engines."""
758 758 return len(self.ids)
759 759
760 760 def __getitem__(self, key):
761 761 """index access returns DirectView multiplexer objects
762 762
763 763 Must be int, slice, or list/tuple/xrange of ints"""
764 764 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 765 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 766 else:
767 767 return self.direct_view(key)
768 768
769 769 #--------------------------------------------------------------------------
770 770 # Begin public methods
771 771 #--------------------------------------------------------------------------
772 772
773 773 @property
774 774 def ids(self):
775 775 """Always up-to-date ids property."""
776 776 self._flush_notifications()
777 777 # always copy:
778 778 return list(self._ids)
779 779
780 780 def close(self):
781 781 if self._closed:
782 782 return
783 783 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 784 for socket in map(lambda name: getattr(self, name), snames):
785 785 if isinstance(socket, zmq.Socket) and not socket.closed:
786 786 socket.close()
787 787 self._closed = True
788 788
789 789 def spin(self):
790 790 """Flush any registration notifications and execution results
791 791 waiting in the ZMQ queue.
792 792 """
793 793 if self._notification_socket:
794 794 self._flush_notifications()
795 795 if self._mux_socket:
796 796 self._flush_results(self._mux_socket)
797 797 if self._task_socket:
798 798 self._flush_results(self._task_socket)
799 799 if self._control_socket:
800 800 self._flush_control(self._control_socket)
801 801 if self._iopub_socket:
802 802 self._flush_iopub(self._iopub_socket)
803 803 if self._query_socket:
804 804 self._flush_ignored_hub_replies()
805 805
806 806 def wait(self, jobs=None, timeout=-1):
807 807 """waits on one or more `jobs`, for up to `timeout` seconds.
808 808
809 809 Parameters
810 810 ----------
811 811
812 812 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 813 ints are indices to self.history
814 814 strs are msg_ids
815 815 default: wait on all outstanding messages
816 816 timeout : float
817 817 a time in seconds, after which to give up.
818 818 default is -1, which means no timeout
819 819
820 820 Returns
821 821 -------
822 822
823 823 True : when all msg_ids are done
824 824 False : timeout reached, some msg_ids still outstanding
825 825 """
826 826 tic = time.time()
827 827 if jobs is None:
828 828 theids = self.outstanding
829 829 else:
830 830 if isinstance(jobs, (int, basestring, AsyncResult)):
831 831 jobs = [jobs]
832 832 theids = set()
833 833 for job in jobs:
834 834 if isinstance(job, int):
835 835 # index access
836 836 job = self.history[job]
837 837 elif isinstance(job, AsyncResult):
838 838 map(theids.add, job.msg_ids)
839 839 continue
840 840 theids.add(job)
841 841 if not theids.intersection(self.outstanding):
842 842 return True
843 843 self.spin()
844 844 while theids.intersection(self.outstanding):
845 845 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 846 break
847 847 time.sleep(1e-3)
848 848 self.spin()
849 849 return len(theids.intersection(self.outstanding)) == 0
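For example, continuing with the `rc` and `view` objects from the sketch at the top (timings illustrative):

    import time

    ar = view.apply_async(time.sleep, 5)
    rc.wait([ar], timeout=1)     # False: still outstanding after one second
    rc.wait()                    # True, after blocking until all work finishes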
850 850
851 851 #--------------------------------------------------------------------------
852 852 # Control methods
853 853 #--------------------------------------------------------------------------
854 854
855 855 @spin_first
856 856 def clear(self, targets=None, block=None):
857 857 """Clear the namespace in target(s)."""
858 858 block = self.block if block is None else block
859 859 targets = self._build_targets(targets)[0]
860 860 for t in targets:
861 861 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 862 error = False
863 863 if block:
864 864 self._flush_ignored_control()
865 865 for i in range(len(targets)):
866 866 idents,msg = self.session.recv(self._control_socket,0)
867 867 if self.debug:
868 868 pprint(msg)
869 869 if msg['content']['status'] != 'ok':
870 870 error = self._unwrap_exception(msg['content'])
871 871 else:
872 872 self._ignored_control_replies += len(targets)
873 873 if error:
874 874 raise error
875 875
876 876
877 877 @spin_first
878 878 def abort(self, jobs=None, targets=None, block=None):
879 879 """Abort specific jobs from the execution queues of target(s).
880 880
881 881 This is a mechanism to prevent jobs that have already been submitted
882 882 from executing.
883 883
884 884 Parameters
885 885 ----------
886 886
887 887 jobs : msg_id, list of msg_ids, or AsyncResult
888 888 The jobs to be aborted
889
889
890 If unspecified/None: abort all outstanding jobs.
890 891
891 892 """
892 893 block = self.block if block is None else block
894 jobs = jobs if jobs is not None else list(self.outstanding)
893 895 targets = self._build_targets(targets)[0]
896
894 897 msg_ids = []
895 898 if isinstance(jobs, (basestring,AsyncResult)):
896 899 jobs = [jobs]
897 900 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
898 901 if bad_ids:
899 902 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
900 903 for j in jobs:
901 904 if isinstance(j, AsyncResult):
902 905 msg_ids.extend(j.msg_ids)
903 906 else:
904 907 msg_ids.append(j)
905 908 content = dict(msg_ids=msg_ids)
906 909 for t in targets:
907 910 self.session.send(self._control_socket, 'abort_request',
908 911 content=content, ident=t)
909 912 error = False
910 913 if block:
911 914 self._flush_ignored_control()
912 915 for i in range(len(targets)):
913 916 idents,msg = self.session.recv(self._control_socket,0)
914 917 if self.debug:
915 918 pprint(msg)
916 919 if msg['content']['status'] != 'ok':
917 920 error = self._unwrap_exception(msg['content'])
918 921 else:
919 922 self._ignored_control_replies += len(targets)
920 923 if error:
921 924 raise error
922 925
923 926 @spin_first
924 927 def shutdown(self, targets=None, restart=False, hub=False, block=None):
925 928 """Terminates one or more engine processes, optionally including the hub."""
926 929 block = self.block if block is None else block
927 930 if hub:
928 931 targets = 'all'
929 932 targets = self._build_targets(targets)[0]
930 933 for t in targets:
931 934 self.session.send(self._control_socket, 'shutdown_request',
932 935 content={'restart':restart},ident=t)
933 936 error = False
934 937 if block or hub:
935 938 self._flush_ignored_control()
936 939 for i in range(len(targets)):
937 940 idents,msg = self.session.recv(self._control_socket, 0)
938 941 if self.debug:
939 942 pprint(msg)
940 943 if msg['content']['status'] != 'ok':
941 944 error = self._unwrap_exception(msg['content'])
942 945 else:
943 946 self._ignored_control_replies += len(targets)
944 947
945 948 if hub:
946 949 time.sleep(0.25)
947 950 self.session.send(self._query_socket, 'shutdown_request')
948 951 idents,msg = self.session.recv(self._query_socket, 0)
949 952 if self.debug:
950 953 pprint(msg)
951 954 if msg['content']['status'] != 'ok':
952 955 error = self._unwrap_exception(msg['content'])
953 956
954 957 if error:
955 958 raise error
956 959
957 960 #--------------------------------------------------------------------------
958 961 # Execution related methods
959 962 #--------------------------------------------------------------------------
960 963
961 964 def _maybe_raise(self, result):
962 965 """wrapper for maybe raising an exception if apply failed."""
963 966 if isinstance(result, error.RemoteError):
964 967 raise result
965 968
966 969 return result
967 970
968 971 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
969 972 ident=None):
970 973 """construct and send an apply message via a socket.
971 974
972 975 This is the principal method with which all engine execution is performed by views.
973 976 """
974 977
975 978 assert not self._closed, "cannot use me anymore, I'm closed!"
976 979 # defaults:
977 980 args = args if args is not None else []
978 981 kwargs = kwargs if kwargs is not None else {}
979 982 subheader = subheader if subheader is not None else {}
980 983
981 984 # validate arguments
982 985 if not callable(f):
983 986 raise TypeError("f must be callable, not %s"%type(f))
984 987 if not isinstance(args, (tuple, list)):
985 988 raise TypeError("args must be tuple or list, not %s"%type(args))
986 989 if not isinstance(kwargs, dict):
987 990 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
988 991 if not isinstance(subheader, dict):
989 992 raise TypeError("subheader must be dict, not %s"%type(subheader))
990 993
991 994 bufs = util.pack_apply_message(f,args,kwargs)
992 995
993 996 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
994 997 subheader=subheader, track=track)
995 998
996 999 msg_id = msg['header']['msg_id']
997 1000 self.outstanding.add(msg_id)
998 1001 if ident:
999 1002 # possibly routed to a specific engine
1000 1003 if isinstance(ident, list):
1001 1004 ident = ident[-1]
1002 1005 if ident in self._engines.values():
1003 1006 # save for later, in case of engine death
1004 1007 self._outstanding_dict[ident].add(msg_id)
1005 1008 self.history.append(msg_id)
1006 1009 self.metadata[msg_id]['submitted'] = datetime.now()
1007 1010
1008 1011 return msg
1009 1012
1010 1013 #--------------------------------------------------------------------------
1011 1014 # construct a View object
1012 1015 #--------------------------------------------------------------------------
1013 1016
1014 1017 def load_balanced_view(self, targets=None):
1015 1018 """construct a DirectView object.
1016 1019
1017 1020 If no arguments are specified, create a LoadBalancedView
1018 1021 using all engines.
1019 1022
1020 1023 Parameters
1021 1024 ----------
1022 1025
1023 1026 targets: list,slice,int,etc. [default: use all engines]
1024 1027 The subset of engines across which to load-balance
1025 1028 """
1026 1029 if targets == 'all':
1027 1030 targets = None
1028 1031 if targets is not None:
1029 1032 targets = self._build_targets(targets)[1]
1030 1033 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1031 1034
1032 1035 def direct_view(self, targets='all'):
1033 1036 """construct a DirectView object.
1034 1037
1035 1038 If no targets are specified, create a DirectView using all engines.
1036 1039
1037 1040 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1038 1041 evaluate the target engines at each execution, whereas rc[:] will connect to
1039 1042 all *current* engines, and that list will not change.
1040 1043
1041 1044 That is, 'all' will always use all engines, whereas rc[:] will not use
1042 1045 engines added after the DirectView is constructed.
1043 1046
1044 1047 Parameters
1045 1048 ----------
1046 1049
1047 1050 targets: list,slice,int,etc. [default: use all engines]
1048 1051 The engines to use for the View
1049 1052 """
1050 1053 single = isinstance(targets, int)
1051 1054 # allow 'all' to be lazily evaluated at each execution
1052 1055 if targets != 'all':
1053 1056 targets = self._build_targets(targets)[1]
1054 1057 if single:
1055 1058 targets = targets[0]
1056 1059 return DirectView(client=self, socket=self._mux_socket, targets=targets)
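The lazy-vs-snapshot distinction described above, in a short sketch:

    fixed = rc[:]                   # binds to the engines registered *right now*
    lazy = rc.direct_view('all')    # re-evaluates 'all' at each execution
    # engines registered later are used by `lazy`, but never by `fixed`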
1057 1060
1058 1061 #--------------------------------------------------------------------------
1059 1062 # Query methods
1060 1063 #--------------------------------------------------------------------------
1061 1064
1062 1065 @spin_first
1063 1066 def get_result(self, indices_or_msg_ids=None, block=None):
1064 1067 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1065 1068
1066 1069 If the client already has the results, no request to the Hub will be made.
1067 1070
1068 1071 This is a convenient way to construct AsyncResult objects, which are wrappers
1069 1072 that include metadata about execution, and allow for awaiting results that
1070 1073 were not submitted by this Client.
1071 1074
1072 1075 It can also be a convenient way to retrieve the metadata associated with
1073 1076 blocking execution, since it always retrieves the metadata as well.
1074 1077
1075 1078 Examples
1076 1079 --------
1077 1080 ::
1078 1081
1079 1082 In [10]: ar = client.get_result(msg_id)  # msg_id from client.history
1080 1083
1081 1084 Parameters
1082 1085 ----------
1083 1086
1084 1087 indices_or_msg_ids : integer history index, str msg_id, or list of either
1085 1088 The indices or msg_ids of indices to be retrieved
1086 1089
1087 1090 block : bool
1088 1091 Whether to wait for the result to be done
1089 1092
1090 1093 Returns
1091 1094 -------
1092 1095
1093 1096 AsyncResult
1094 1097 A single AsyncResult object will always be returned.
1095 1098
1096 1099 AsyncHubResult
1097 1100 A subclass of AsyncResult that retrieves results from the Hub
1098 1101
1099 1102 """
1100 1103 block = self.block if block is None else block
1101 1104 if indices_or_msg_ids is None:
1102 1105 indices_or_msg_ids = -1
1103 1106
1104 1107 if not isinstance(indices_or_msg_ids, (list,tuple)):
1105 1108 indices_or_msg_ids = [indices_or_msg_ids]
1106 1109
1107 1110 theids = []
1108 1111 for id in indices_or_msg_ids:
1109 1112 if isinstance(id, int):
1110 1113 id = self.history[id]
1111 1114 if not isinstance(id, basestring):
1112 1115 raise TypeError("indices must be str or int, not %r"%id)
1113 1116 theids.append(id)
1114 1117
1115 1118 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1116 1119 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1117 1120
1118 1121 if remote_ids:
1119 1122 ar = AsyncHubResult(self, msg_ids=theids)
1120 1123 else:
1121 1124 ar = AsyncResult(self, msg_ids=theids)
1122 1125
1123 1126 if block:
1124 1127 ar.wait()
1125 1128
1126 1129 return ar
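A small sketch of the round trip, taking the msg_id from the client's own history:

    msg_id = rc.history[-1]       # most recent submission
    ar = rc.get_result(msg_id)    # AsyncResult, or AsyncHubResult if unknown locally
    ar.wait()
    print(ar.get())               # the result itself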
1127 1130
1128 1131 @spin_first
1129 1132 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1130 1133 """Resubmit one or more tasks.
1131 1134
1132 1135 in-flight tasks may not be resubmitted.
1133 1136
1134 1137 Parameters
1135 1138 ----------
1136 1139
1137 1140 indices_or_msg_ids : integer history index, str msg_id, or list of either
1138 1141 The indices or msg_ids of indices to be retrieved
1139 1142
1140 1143 block : bool
1141 1144 Whether to wait for the result to be done
1142 1145
1143 1146 Returns
1144 1147 -------
1145 1148
1146 1149 AsyncHubResult
1147 1150 A subclass of AsyncResult that retrieves results from the Hub
1148 1151
1149 1152 """
1150 1153 block = self.block if block is None else block
1151 1154 if indices_or_msg_ids is None:
1152 1155 indices_or_msg_ids = -1
1153 1156
1154 1157 if not isinstance(indices_or_msg_ids, (list,tuple)):
1155 1158 indices_or_msg_ids = [indices_or_msg_ids]
1156 1159
1157 1160 theids = []
1158 1161 for id in indices_or_msg_ids:
1159 1162 if isinstance(id, int):
1160 1163 id = self.history[id]
1161 1164 if not isinstance(id, basestring):
1162 1165 raise TypeError("indices must be str or int, not %r"%id)
1163 1166 theids.append(id)
1164 1167
1165 1168 for msg_id in theids:
1166 1169 self.outstanding.discard(msg_id)
1167 1170 if msg_id in self.history:
1168 1171 self.history.remove(msg_id)
1169 1172 self.results.pop(msg_id, None)
1170 1173 self.metadata.pop(msg_id, None)
1171 1174 content = dict(msg_ids = theids)
1172 1175
1173 1176 self.session.send(self._query_socket, 'resubmit_request', content)
1174 1177
1175 1178 zmq.select([self._query_socket], [], [])
1176 1179 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1177 1180 if self.debug:
1178 1181 pprint(msg)
1179 1182 content = msg['content']
1180 1183 if content['status'] != 'ok':
1181 1184 raise self._unwrap_exception(content)
1182 1185
1183 1186 ar = AsyncHubResult(self, msg_ids=theids)
1184 1187
1185 1188 if block:
1186 1189 ar.wait()
1187 1190
1188 1191 return ar
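For instance, re-running the most recent (completed) task:

    ar = rc.resubmit(rc.history[-1])   # equivalently rc.resubmit(), which defaults to -1
    ar.get()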
1189 1192
1190 1193 @spin_first
1191 1194 def result_status(self, msg_ids, status_only=True):
1192 1195 """Check on the status of the result(s) of the apply request with `msg_ids`.
1193 1196
1194 1197 If status_only is False, then the actual results will be retrieved, else
1195 1198 only the status of the results will be checked.
1196 1199
1197 1200 Parameters
1198 1201 ----------
1199 1202
1200 1203 msg_ids : list of msg_ids
1201 1204 if int:
1202 1205 Passed as index to self.history for convenience.
1203 1206 status_only : bool (default: True)
1204 1207 if False:
1205 1208 Retrieve the actual results of completed tasks.
1206 1209
1207 1210 Returns
1208 1211 -------
1209 1212
1210 1213 results : dict
1211 1214 There will always be the keys 'pending' and 'completed', which will
1212 1215 be lists of msg_ids that are incomplete or complete. If `status_only`
1213 1216 is False, then completed results will be keyed by their `msg_id`.
1214 1217 """
1215 1218 if not isinstance(msg_ids, (list,tuple)):
1216 1219 msg_ids = [msg_ids]
1217 1220
1218 1221 theids = []
1219 1222 for msg_id in msg_ids:
1220 1223 if isinstance(msg_id, int):
1221 1224 msg_id = self.history[msg_id]
1222 1225 if not isinstance(msg_id, basestring):
1223 1226 raise TypeError("msg_ids must be str, not %r"%msg_id)
1224 1227 theids.append(msg_id)
1225 1228
1226 1229 completed = []
1227 1230 local_results = {}
1228 1231
1229 1232 # comment this block out to temporarily disable local shortcut:
1230 1233 for msg_id in theids:
1231 1234 if msg_id in self.results:
1232 1235 completed.append(msg_id)
1233 1236 local_results[msg_id] = self.results[msg_id]
1234 1237 theids.remove(msg_id)
1235 1238
1236 1239 if theids: # some not locally cached
1237 1240 content = dict(msg_ids=theids, status_only=status_only)
1238 1241 msg = self.session.send(self._query_socket, "result_request", content=content)
1239 1242 zmq.select([self._query_socket], [], [])
1240 1243 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1241 1244 if self.debug:
1242 1245 pprint(msg)
1243 1246 content = msg['content']
1244 1247 if content['status'] != 'ok':
1245 1248 raise self._unwrap_exception(content)
1246 1249 buffers = msg['buffers']
1247 1250 else:
1248 1251 content = dict(completed=[],pending=[])
1249 1252
1250 1253 content['completed'].extend(completed)
1251 1254
1252 1255 if status_only:
1253 1256 return content
1254 1257
1255 1258 failures = []
1256 1259 # load cached results into result:
1257 1260 content.update(local_results)
1258 1261
1259 1262 # update cache with results:
1260 1263 for msg_id in sorted(theids):
1261 1264 if msg_id in content['completed']:
1262 1265 rec = content[msg_id]
1263 1266 parent = rec['header']
1264 1267 header = rec['result_header']
1265 1268 rcontent = rec['result_content']
1266 1269 iodict = rec['io']
1267 1270 if isinstance(rcontent, str):
1268 1271 rcontent = self.session.unpack(rcontent)
1269 1272
1270 1273 md = self.metadata[msg_id]
1271 1274 md.update(self._extract_metadata(header, parent, rcontent))
1272 1275 md.update(iodict)
1273 1276
1274 1277 if rcontent['status'] == 'ok':
1275 1278 res,buffers = util.unserialize_object(buffers)
1276 1279 else:
1277 1280 print rcontent
1278 1281 res = self._unwrap_exception(rcontent)
1279 1282 failures.append(res)
1280 1283
1281 1284 self.results[msg_id] = res
1282 1285 content[msg_id] = res
1283 1286
1284 1287 if len(theids) == 1 and failures:
1285 1288 raise failures[0]
1286 1289
1287 1290 error.collect_exceptions(failures, "result_status")
1288 1291 return content
1289 1292
1290 1293 @spin_first
1291 1294 def queue_status(self, targets='all', verbose=False):
1292 1295 """Fetch the status of engine queues.
1293 1296
1294 1297 Parameters
1295 1298 ----------
1296 1299
1297 1300 targets : int/str/list of ints/strs
1298 1301 the engines whose states are to be queried.
1299 1302 default : all
1300 1303 verbose : bool
1301 1304 Whether to return lengths only, or lists of ids for each element
1302 1305 """
1303 1306 engine_ids = self._build_targets(targets)[1]
1304 1307 content = dict(targets=engine_ids, verbose=verbose)
1305 1308 self.session.send(self._query_socket, "queue_request", content=content)
1306 1309 idents,msg = self.session.recv(self._query_socket, 0)
1307 1310 if self.debug:
1308 1311 pprint(msg)
1309 1312 content = msg['content']
1310 1313 status = content.pop('status')
1311 1314 if status != 'ok':
1312 1315 raise self._unwrap_exception(content)
1313 1316 content = rekey(content)
1314 1317 if isinstance(targets, int):
1315 1318 return content[targets]
1316 1319 else:
1317 1320 return content
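Illustrative calls (the counts in the result depend on cluster state):

    rc.queue_status()     # e.g. {0: {'queue': 0, 'completed': 4, 'tasks': 0}, ...}
    rc.queue_status(0)    # just the status dict for engine 0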
1318 1321
1319 1322 @spin_first
1320 1323 def purge_results(self, jobs=[], targets=[]):
1321 1324 """Tell the Hub to forget results.
1322 1325
1323 1326 Individual results can be purged by msg_id, or the entire
1324 1327 history of specific targets can be purged.
1325 1328
1326 1329 Use `purge_results('all')` to scrub everything from the Hub's db.
1327 1330
1328 1331 Parameters
1329 1332 ----------
1330 1333
1331 1334 jobs : str or list of str or AsyncResult objects
1332 1335 the msg_ids whose results should be forgotten.
1333 1336 targets : int/str/list of ints/strs
1334 1337 The targets, by int_id, whose entire history is to be purged.
1335 1338
1336 1339 default : None
1337 1340 """
1338 1341 if not targets and not jobs:
1339 1342 raise ValueError("Must specify at least one of `targets` and `jobs`")
1340 1343 if targets:
1341 1344 targets = self._build_targets(targets)[1]
1342 1345
1343 1346 # construct msg_ids from jobs
1344 1347 if jobs == 'all':
1345 1348 msg_ids = jobs
1346 1349 else:
1347 1350 msg_ids = []
1348 1351 if isinstance(jobs, (basestring,AsyncResult)):
1349 1352 jobs = [jobs]
1350 1353 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1351 1354 if bad_ids:
1352 1355 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1353 1356 for j in jobs:
1354 1357 if isinstance(j, AsyncResult):
1355 1358 msg_ids.extend(j.msg_ids)
1356 1359 else:
1357 1360 msg_ids.append(j)
1358 1361
1359 1362 content = dict(engine_ids=targets, msg_ids=msg_ids)
1360 1363 self.session.send(self._query_socket, "purge_request", content=content)
1361 1364 idents, msg = self.session.recv(self._query_socket, 0)
1362 1365 if self.debug:
1363 1366 pprint(msg)
1364 1367 content = msg['content']
1365 1368 if content['status'] != 'ok':
1366 1369 raise self._unwrap_exception(content)
1367 1370
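Typical calls, per the docstring above (`ar` is any AsyncResult):

    rc.purge_results('all')         # scrub every result from the Hub's db
    rc.purge_results(jobs=ar)       # forget the msg_ids behind one AsyncResult
    rc.purge_results(targets=[0])   # or purge engine 0's entire history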
1368 1371 @spin_first
1369 1372 def hub_history(self):
1370 1373 """Get the Hub's history
1371 1374
1372 1375 Just like the Client, the Hub has a history, which is a list of msg_ids.
1373 1376 This will contain the history of all clients, and, depending on configuration,
1374 1377 may contain history across multiple cluster sessions.
1375 1378
1376 1379 Any msg_id returned here is a valid argument to `get_result`.
1377 1380
1378 1381 Returns
1379 1382 -------
1380 1383
1381 1384 msg_ids : list of strs
1382 1385 list of all msg_ids, ordered by task submission time.
1383 1386 """
1384 1387
1385 1388 self.session.send(self._query_socket, "history_request", content={})
1386 1389 idents, msg = self.session.recv(self._query_socket, 0)
1387 1390
1388 1391 if self.debug:
1389 1392 pprint(msg)
1390 1393 content = msg['content']
1391 1394 if content['status'] != 'ok':
1392 1395 raise self._unwrap_exception(content)
1393 1396 else:
1394 1397 return content['history']
1395 1398
1396 1399 @spin_first
1397 1400 def db_query(self, query, keys=None):
1398 1401 """Query the Hub's TaskRecord database
1399 1402
1400 1403 This will return a list of task record dicts that match `query`
1401 1404
1402 1405 Parameters
1403 1406 ----------
1404 1407
1405 1408 query : mongodb query dict
1406 1409 The search dict. See mongodb query docs for details.
1407 1410 keys : list of strs [optional]
1408 1411 The subset of keys to be returned. The default is to fetch everything but buffers.
1409 1412 'msg_id' will *always* be included.
1410 1413 """
1411 1414 if isinstance(keys, basestring):
1412 1415 keys = [keys]
1413 1416 content = dict(query=query, keys=keys)
1414 1417 self.session.send(self._query_socket, "db_request", content=content)
1415 1418 idents, msg = self.session.recv(self._query_socket, 0)
1416 1419 if self.debug:
1417 1420 pprint(msg)
1418 1421 content = msg['content']
1419 1422 if content['status'] != 'ok':
1420 1423 raise self._unwrap_exception(content)
1421 1424
1422 1425 records = content['records']
1423 1426
1424 1427 buffer_lens = content['buffer_lens']
1425 1428 result_buffer_lens = content['result_buffer_lens']
1426 1429 buffers = msg['buffers']
1427 1430 has_bufs = buffer_lens is not None
1428 1431 has_rbufs = result_buffer_lens is not None
1429 1432 for i,rec in enumerate(records):
1430 1433 # relink buffers
1431 1434 if has_bufs:
1432 1435 blen = buffer_lens[i]
1433 1436 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1434 1437 if has_rbufs:
1435 1438 blen = result_buffer_lens[i]
1436 1439 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1437 1440
1438 1441 return records
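For example, a hypothetical query for tasks submitted in the last hour, fetching only two keys:

    from datetime import datetime, timedelta

    since = datetime.now() - timedelta(hours=1)
    recs = rc.db_query({'submitted': {'$gte': since}},
                       keys=['msg_id', 'completed'])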
1439 1442
1440 1443 __all__ = [ 'Client' ]
--- a/IPython/parallel/client/view.py
+++ b/IPython/parallel/client/view.py
@@ -1,1057 +1,1059 @@
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89 control methods spin as well, and requesting `ids` also ensures up-to-date state
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert self.__class__ is not View, "Don't use base View objects, use subclasses"
130 130
131 131
132 132 def __repr__(self):
133 133 strtargets = str(self.targets)
134 134 if len(strtargets) > 16:
135 135 strtargets = strtargets[:12]+'...]'
136 136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137 137
138 138 def set_flags(self, **kwargs):
139 139 """set my attribute flags by keyword.
140 140
141 141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 142 These attributes can be set all at once by name with this method.
143 143
144 144 Parameters
145 145 ----------
146 146
147 147 block : bool
148 148 whether to wait for results
149 149 track : bool
150 150 whether to create a MessageTracker to allow the user to
151 151 safely edit arrays and buffers after non-copying
152 152 sends.
153 153 """
154 154 for name, value in kwargs.iteritems():
155 155 if name not in self._flag_names:
156 156 raise KeyError("Invalid name: %r"%name)
157 157 else:
158 158 setattr(self, name, value)
159 159
160 160 @contextmanager
161 161 def temp_flags(self, **kwargs):
162 162 """temporarily set flags, for use in `with` statements.
163 163
164 164 See set_flags for permanent setting of flags
165 165
166 166 Examples
167 167 --------
168 168
169 169 >>> view.track=False
170 170 ...
171 171 >>> with view.temp_flags(track=True):
172 172 ... ar = view.apply(dostuff, my_big_array)
173 173 ... ar.tracker.wait() # wait for send to finish
174 174 >>> view.track
175 175 False
176 176
177 177 """
178 178 # preflight: save flags, and set temporaries
179 179 saved_flags = {}
180 180 for f in self._flag_names:
181 181 saved_flags[f] = getattr(self, f)
182 182 self.set_flags(**kwargs)
183 183 # yield to the with-statement block
184 184 try:
185 185 yield
186 186 finally:
187 187 # postflight: restore saved flags
188 188 self.set_flags(**saved_flags)
189 189
190 190
191 191 #----------------------------------------------------------------
192 192 # apply
193 193 #----------------------------------------------------------------
194 194
195 195 @sync_results
196 196 @save_ids
197 197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 198 """wrapper for client.send_apply_message"""
199 199 raise NotImplementedError("Implement in subclasses")
200 200
201 201 def apply(self, f, *args, **kwargs):
202 202 """calls f(*args, **kwargs) on remote engines, returning the result.
203 203
204 204 This method sets all apply flags via this View's attributes.
205 205
206 206 if self.block is False:
207 207 returns AsyncResult
208 208 else:
209 209 returns actual result of f(*args, **kwargs)
210 210 """
211 211 return self._really_apply(f, args, kwargs)
212 212
213 213 def apply_async(self, f, *args, **kwargs):
214 214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215 215
216 216 returns AsyncResult
217 217 """
218 218 return self._really_apply(f, args, kwargs, block=False)
219 219
220 220 @spin_after
221 221 def apply_sync(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 223 returning the result.
224 224
225 225 returns: actual result of f(*args, **kwargs)
226 226 """
227 227 return self._really_apply(f, args, kwargs, block=True)
228 228
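A minimal sketch of the three apply variants, assuming a cluster is already running (e.g. started with `ipcluster start`) and reachable with default settings; `double` is a hypothetical helper:

    from IPython.parallel import Client

    rc = Client()                  # assumes a running controller and engines
    v = rc[:]                      # DirectView over all engines

    def double(x):
        return 2 * x

    results = v.apply_sync(double, 5)   # blocks; one result per engine
    ar = v.apply_async(double, 5)       # returns an AsyncResult immediately
    print(ar.get(timeout=5))            # fetch results, waiting up to 5s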
229 229 #----------------------------------------------------------------
230 230 # wrappers for client and control methods
231 231 #----------------------------------------------------------------
232 232 @sync_results
233 233 def spin(self):
234 234 """spin the client, and sync"""
235 235 self.client.spin()
236 236
237 237 @sync_results
238 238 def wait(self, jobs=None, timeout=-1):
239 239 """waits on one or more `jobs`, for up to `timeout` seconds.
240 240
241 241 Parameters
242 242 ----------
243 243
244 244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 245 ints are indices to self.history
246 246 strs are msg_ids
247 247 default: wait on all outstanding messages
248 248 timeout : float
249 249 a time in seconds, after which to give up.
250 250 default is -1, which means no timeout
251 251
252 252 Returns
253 253 -------
254 254
255 255 True : when all msg_ids are done
256 256 False : timeout reached, some msg_ids still outstanding
257 257 """
258 258 if jobs is None:
259 259 jobs = self.history
260 260 return self.client.wait(jobs, timeout)
261 261
262 262 def abort(self, jobs=None, targets=None, block=None):
263 263 """Abort jobs on my engines.
264 264
265 265 Parameters
266 266 ----------
267 267
268 268 jobs : None, str, list of strs, optional
269 269 if None: abort all jobs.
270 270 else: abort specific msg_id(s).
271 271 """
272 272 block = block if block is not None else self.block
273 273 targets = targets if targets is not None else self.targets
274 jobs = jobs if jobs is not None else list(self.outstanding)
275
274 276 return self.client.abort(jobs=jobs, targets=targets, block=block)
275 277
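A short sketch of aborting outstanding work, assuming the same running cluster; with no arguments, `abort` falls back to everything in `view.outstanding`:

    import time
    from IPython.parallel import Client

    rc = Client()
    v = rc[:]
    ars = [v.apply_async(time.sleep, 10) for _ in range(4)]  # queue slow tasks
    v.abort()        # abort all outstanding tasks on this view's engines
    # fetching an aborted task's result raises error.TaskAborted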
276 278 def queue_status(self, targets=None, verbose=False):
277 279 """Fetch the Queue status of my engines"""
278 280 targets = targets if targets is not None else self.targets
279 281 return self.client.queue_status(targets=targets, verbose=verbose)
280 282
281 283 def purge_results(self, jobs=[], targets=[]):
282 284 """Instruct the controller to forget specific results."""
283 285 if targets is None or targets == 'all':
284 286 targets = self.targets
285 287 return self.client.purge_results(jobs=jobs, targets=targets)
286 288
287 289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
288 290 """Terminates one or more engine processes, optionally including the hub.
289 291 """
290 292 block = self.block if block is None else block
291 293 if targets is None or targets == 'all':
292 294 targets = self.targets
293 295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
294 296
295 297 @spin_after
296 298 def get_result(self, indices_or_msg_ids=None):
297 299 """return one or more results, specified by history index or msg_id.
298 300
299 301 See client.get_result for details.
300 302
301 303 """
302 304
303 305 if indices_or_msg_ids is None:
304 306 indices_or_msg_ids = -1
305 307 if isinstance(indices_or_msg_ids, int):
306 308 indices_or_msg_ids = self.history[indices_or_msg_ids]
307 309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
308 310 indices_or_msg_ids = list(indices_or_msg_ids)
309 311 for i,index in enumerate(indices_or_msg_ids):
310 312 if isinstance(index, int):
311 313 indices_or_msg_ids[i] = self.history[index]
312 314 return self.client.get_result(indices_or_msg_ids)
313 315
314 316 #-------------------------------------------------------------------
315 317 # Map
316 318 #-------------------------------------------------------------------
317 319
318 320 def map(self, f, *sequences, **kwargs):
319 321 """override in subclasses"""
320 322 raise NotImplementedError
321 323
322 324 def map_async(self, f, *sequences, **kwargs):
323 325 """Parallel version of builtin `map`, using this view's engines.
324 326
325 327 This is equivalent to map(...block=False)
326 328
327 329 See `self.map` for details.
328 330 """
329 331 if 'block' in kwargs:
330 332 raise TypeError("map_async doesn't take a `block` keyword argument.")
331 333 kwargs['block'] = False
332 334 return self.map(f,*sequences,**kwargs)
333 335
334 336 def map_sync(self, f, *sequences, **kwargs):
335 337 """Parallel version of builtin `map`, using this view's engines.
336 338
337 339 This is equivalent to map(...block=True)
338 340
339 341 See `self.map` for details.
340 342 """
341 343 if 'block' in kwargs:
342 344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
343 345 kwargs['block'] = True
344 346 return self.map(f,*sequences,**kwargs)
345 347
346 348 def imap(self, f, *sequences, **kwargs):
347 349 """Parallel version of `itertools.imap`.
348 350
349 351 See `self.map` for details.
350 352
351 353 """
352 354
353 355 return iter(self.map_async(f,*sequences, **kwargs))
354 356
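A hedged sketch contrasting the map variants, under the same running-cluster assumption:

    from IPython.parallel import Client

    rc = Client()
    v = rc[:]

    squares = v.map_sync(lambda x: x**2, range(8))   # blocks, returns a list
    amr = v.map_async(lambda x: x**2, range(8))      # AsyncMapResult, nonblocking
    for r in v.imap(lambda x: x**2, range(8)):       # iterate as chunks arrive
        print(r)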
355 357 #-------------------------------------------------------------------
356 358 # Decorators
357 359 #-------------------------------------------------------------------
358 360
359 361 def remote(self, block=True, **flags):
360 362 """Decorator for making a RemoteFunction"""
361 363 block = self.block if block is None else block
362 364 return remote(self, block=block, **flags)
363 365
364 366 def parallel(self, dist='b', block=None, **flags):
365 367 """Decorator for making a ParallelFunction"""
366 368 block = self.block if block is None else block
367 369 return parallel(self, dist=dist, block=block, **flags)
368 370
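A minimal sketch of the decorator forms, assuming a running cluster; `getpid` and `double` are hypothetical:

    from IPython.parallel import Client

    rc = Client()
    v = rc[:]

    @v.remote(block=True)
    def getpid():
        import os
        return os.getpid()

    @v.parallel(dist='b', block=True)
    def double(x):
        return 2 * x

    pids = getpid()                   # runs once on every engine in the view
    doubled = double.map(range(16))   # elementwise map, distributed across engines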
369 371 @skip_doctest
370 372 class DirectView(View):
371 373 """Direct Multiplexer View of one or more engines.
372 374
373 375 These are created via indexed access to a client:
374 376
375 377 >>> dv_1 = client[1]
376 378 >>> dv_all = client[:]
377 379 >>> dv_even = client[::2]
378 380 >>> dv_some = client[1:3]
379 381
380 382 This object provides dictionary access to engine namespaces:
381 383
382 384 # push a=5:
383 385 >>> dv['a'] = 5
384 386 # pull 'foo':
385 387 >>> dv['foo']
386 388
387 389 """
388 390
389 391 def __init__(self, client=None, socket=None, targets=None):
390 392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
391 393
392 394 @property
393 395 def importer(self):
394 396 """sync_imports(local=True) as a property.
395 397
396 398 See sync_imports for details.
397 399
398 400 """
399 401 return self.sync_imports(True)
400 402
401 403 @contextmanager
402 404 def sync_imports(self, local=True):
403 405 """Context Manager for performing simultaneous local and remote imports.
404 406
405 407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
406 408
407 409 >>> with view.sync_imports():
408 410 ... from numpy import recarray
409 411 importing recarray from numpy on engine(s)
410 412
411 413 """
412 414 import __builtin__
413 415 local_import = __builtin__.__import__
414 416 modules = set()
415 417 results = []
416 418 @util.interactive
417 419 def remote_import(name, fromlist, level):
418 420 """the function to be passed to apply, that actually performs the import
419 421 on the engine, and loads up the user namespace.
420 422 """
421 423 import sys
422 424 user_ns = globals()
423 425 mod = __import__(name, fromlist=fromlist, level=level)
424 426 if fromlist:
425 427 for key in fromlist:
426 428 user_ns[key] = getattr(mod, key)
427 429 else:
428 430 user_ns[name] = sys.modules[name]
429 431
430 432 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
431 433 """the drop-in replacement for __import__, that optionally imports
432 434 locally as well.
433 435 """
434 436 # don't override nested imports
435 437 save_import = __builtin__.__import__
436 438 __builtin__.__import__ = local_import
437 439
438 440 if imp.lock_held():
439 441 # this is a side-effect import, don't do it remotely, or even
440 442 # ignore the local effects
441 443 return local_import(name, globals, locals, fromlist, level)
442 444
443 445 imp.acquire_lock()
444 446 if local:
445 447 mod = local_import(name, globals, locals, fromlist, level)
446 448 else:
447 449 raise NotImplementedError("remote-only imports not yet implemented")
448 450 imp.release_lock()
449 451
450 452 key = name+':'+','.join(fromlist or [])
451 453 if level == -1 and key not in modules:
452 454 modules.add(key)
453 455 if fromlist:
454 456 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
455 457 else:
456 458 print "importing %s on engine(s)"%name
457 459 results.append(self.apply_async(remote_import, name, fromlist, level))
458 460 # restore override
459 461 __builtin__.__import__ = save_import
460 462
461 463 return mod
462 464
463 465 # override __import__
464 466 __builtin__.__import__ = view_import
465 467 try:
466 468 # enter the block
467 469 yield
468 470 except ImportError:
469 471 if not local:
470 472 # ignore import errors if not doing local imports
471 473 pass
472 474 finally:
473 475 # always restore __import__
474 476 __builtin__.__import__ = local_import
475 477
476 478 for r in results:
477 479 # raise possible remote ImportErrors here
478 480 r.get()
479 481
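A brief sketch of the context manager in use, assuming a running cluster; note the docstring's caveat that 'as' renames are ignored:

    from IPython.parallel import Client

    rc = Client()
    dv = rc[:]

    with dv.sync_imports():
        import re          # imported locally *and* in every engine's namespace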
480 482
481 483 @sync_results
482 484 @save_ids
483 485 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
484 486 """calls f(*args, **kwargs) on remote engines, returning the result.
485 487
486 488 This method sets all of `apply`'s flags via this View's attributes.
487 489
488 490 Parameters
489 491 ----------
490 492
491 493 f : callable
492 494
493 495 args : list [default: empty]
494 496
495 497 kwargs : dict [default: empty]
496 498
497 499 targets : target list [default: self.targets]
498 500 where to run
499 501 block : bool [default: self.block]
500 502 whether to block
501 503 track : bool [default: self.track]
502 504 whether to ask zmq to track the message, for safe non-copying sends
503 505
504 506 Returns
505 507 -------
506 508
507 509 if self.block is False:
508 510 returns AsyncResult
509 511 else:
510 512 returns actual result of f(*args, **kwargs) on the engine(s)
511 513 This will be a list if self.targets is also a list (even length 1), or
512 514 the single result if self.targets is an integer engine id
513 515 """
514 516 args = [] if args is None else args
515 517 kwargs = {} if kwargs is None else kwargs
516 518 block = self.block if block is None else block
517 519 track = self.track if track is None else track
518 520 targets = self.targets if targets is None else targets
519 521
520 522 _idents = self.client._build_targets(targets)[0]
521 523 msg_ids = []
522 524 trackers = []
523 525 for ident in _idents:
524 526 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
525 527 ident=ident)
526 528 if track:
527 529 trackers.append(msg['tracker'])
528 530 msg_ids.append(msg['header']['msg_id'])
529 531 tracker = None if track is False else zmq.MessageTracker(*trackers)
530 532 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
531 533 if block:
532 534 try:
533 535 return ar.get()
534 536 except KeyboardInterrupt:
535 537 pass
536 538 return ar
537 539
538 540 @spin_after
539 541 def map(self, f, *sequences, **kwargs):
540 542 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
541 543
542 544 Parallel version of builtin `map`, using this View's `targets`.
543 545
544 546 There will be one task per target, so work will be chunked
545 547 if the sequences are longer than `targets`.
546 548
547 549 Results can be iterated as they are ready, but will become available in chunks.
548 550
549 551 Parameters
550 552 ----------
551 553
552 554 f : callable
553 555 function to be mapped
554 556 *sequences: one or more sequences of matching length
555 557 the sequences to be distributed and passed to `f`
556 558 block : bool
557 559 whether to wait for the result or not [default self.block]
558 560
559 561 Returns
560 562 -------
561 563
562 564 if block=False:
563 565 AsyncMapResult
564 566 An object like AsyncResult, but which reassembles the sequence of results
565 567 into a single list. AsyncMapResults can be iterated through before all
566 568 results are complete.
567 569 else:
568 570 list
569 571 the result of map(f,*sequences)
570 572 """
571 573
572 574 block = kwargs.pop('block', self.block)
573 575 for k in kwargs.keys():
574 576 if k not in ['block', 'track']:
575 577 raise TypeError("invalid keyword arg, %r"%k)
576 578
577 579 assert len(sequences) > 0, "must have some sequences to map onto!"
578 580 pf = ParallelFunction(self, f, block=block, **kwargs)
579 581 return pf.map(*sequences)
580 582
581 583 def execute(self, code, targets=None, block=None):
582 584 """Executes `code` on `targets` in blocking or nonblocking manner.
583 585
584 586 ``execute`` is always `bound` (affects engine namespace)
585 587
586 588 Parameters
587 589 ----------
588 590
589 591 code : str
590 592 the code string to be executed
591 593 block : bool
592 594 whether or not to wait until done to return
593 595 default: self.block
594 596 """
595 597 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
596 598
597 599 def run(self, filename, targets=None, block=None):
598 600 """Execute contents of `filename` on my engine(s).
599 601
600 602 This simply reads the contents of the file and calls `execute`.
601 603
602 604 Parameters
603 605 ----------
604 606
605 607 filename : str
606 608 The path to the file
607 609 targets : int/str/list of ints/strs
608 610 the engines on which to execute
609 611 default : all
610 612 block : bool
611 613 whether or not to wait until done
612 614 default: self.block
613 615
614 616 """
615 617 with open(filename, 'r') as f:
616 618 # add newline in case of trailing indented whitespace
617 619 # which will cause SyntaxError
618 620 code = f.read()+'\n'
619 621 return self.execute(code, block=block, targets=targets)
620 622
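A small sketch of `execute` and `run`, assuming a running cluster; `setup_engines.py` here is a hypothetical script path:

    from IPython.parallel import Client

    rc = Client()
    dv = rc[:]

    dv.execute('import os; pid = os.getpid()', block=True)  # runs in engine namespaces
    # dv.run('setup_engines.py', block=True)  # reads the file and hands it to execute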
621 623 def update(self, ns):
622 624 """update remote namespace with dict `ns`
623 625
624 626 See `push` for details.
625 627 """
626 628 return self.push(ns, block=self.block, track=self.track)
627 629
628 630 def push(self, ns, targets=None, block=None, track=None):
629 631 """update remote namespace with dict `ns`
630 632
631 633 Parameters
632 634 ----------
633 635
634 636 ns : dict
635 637 dict of keys with which to update engine namespace(s)
636 638 block : bool [default : self.block]
637 639 whether to wait to be notified of engine receipt
638 640
639 641 """
640 642
641 643 block = block if block is not None else self.block
642 644 track = track if track is not None else self.track
643 645 targets = targets if targets is not None else self.targets
644 646 # applier = self.apply_sync if block else self.apply_async
645 647 if not isinstance(ns, dict):
646 648 raise TypeError("Must be a dict, not %s"%type(ns))
647 649 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
648 650
649 651 def get(self, key_s):
650 652 """get object(s) by `key_s` from remote namespace
651 653
652 654 see `pull` for details.
653 655 """
654 656 # block = block if block is not None else self.block
655 657 return self.pull(key_s, block=True)
656 658
657 659 def pull(self, names, targets=None, block=None):
658 660 """get object(s) by `name` from remote namespace
659 661
660 662 will return one object if it is a key.
661 663 can also take a list of keys, in which case it will return a list of objects.
662 664 """
663 665 block = block if block is not None else self.block
664 666 targets = targets if targets is not None else self.targets
665 667 applier = self.apply_sync if block else self.apply_async
666 668 if isinstance(names, basestring):
667 669 pass
668 670 elif isinstance(names, (list,tuple,set)):
669 671 for key in names:
670 672 if not isinstance(key, basestring):
671 673 raise TypeError("keys must be str, not type %r"%type(key))
672 674 else:
673 675 raise TypeError("names must be strs, not %r"%names)
674 676 return self._really_apply(util._pull, (names,), block=block, targets=targets)
675 677
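A sketch of namespace movement, assuming a running cluster; the dictionary-style access shown below is defined later in this class:

    from IPython.parallel import Client

    rc = Client()
    dv = rc[:]

    dv.push(dict(a=1, b=2), block=True)        # update each engine's namespace
    pairs = dv.pull(('a', 'b'), block=True)    # one [a, b] pair per engine
    dv['c'] = 3        # __setitem__ -> update/push
    c = dv['c']        # __getitem__ -> get/pull (always blocking)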
676 678 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
677 679 """
678 680 Partition a Python sequence and send the partitions to a set of engines.
679 681 """
680 682 block = block if block is not None else self.block
681 683 track = track if track is not None else self.track
682 684 targets = targets if targets is not None else self.targets
683 685
684 686 mapObject = Map.dists[dist]()
685 687 nparts = len(targets)
686 688 msg_ids = []
687 689 trackers = []
688 690 for index, engineid in enumerate(targets):
689 691 partition = mapObject.getPartition(seq, index, nparts)
690 692 if flatten and len(partition) == 1:
691 693 ns = {key: partition[0]}
692 694 else:
693 695 ns = {key: partition}
694 696 r = self.push(ns, block=False, track=track, targets=engineid)
695 697 msg_ids.extend(r.msg_ids)
696 698 if track:
697 699 trackers.append(r._tracker)
698 700
699 701 if track:
700 702 tracker = zmq.MessageTracker(*trackers)
701 703 else:
702 704 tracker = None
703 705
704 706 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
705 707 if block:
706 708 r.wait()
707 709 else:
708 710 return r
709 711
710 712 @sync_results
711 713 @save_ids
712 714 def gather(self, key, dist='b', targets=None, block=None):
713 715 """
714 716 Gather a partitioned sequence on a set of engines as a single local seq.
715 717 """
716 718 block = block if block is not None else self.block
717 719 targets = targets if targets is not None else self.targets
718 720 mapObject = Map.dists[dist]()
719 721 msg_ids = []
720 722
721 723 for index, engineid in enumerate(targets):
722 724 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
723 725
724 726 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
725 727
726 728 if block:
727 729 try:
728 730 return r.get()
729 731 except KeyboardInterrupt:
730 732 pass
731 733 return r
732 734
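A minimal scatter/gather round trip under the same assumptions:

    from IPython.parallel import Client

    rc = Client()
    dv = rc[:]

    dv.scatter('chunk', range(16), block=True)     # partition across engines
    dv.execute('chunk = [x * 10 for x in chunk]', block=True)
    flat = dv.gather('chunk', block=True)          # reassembled in order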
733 735 def __getitem__(self, key):
734 736 return self.get(key)
735 737
736 738 def __setitem__(self,key, value):
737 739 self.update({key:value})
738 740
739 741 def clear(self, targets=None, block=None):
740 742 """Clear the remote namespaces on my engines."""
741 743 block = block if block is not None else self.block
742 744 targets = targets if targets is not None else self.targets
743 745 return self.client.clear(targets=targets, block=block)
744 746
745 747 def kill(self, targets=None, block=None):
746 748 """Kill my engines."""
747 749 block = block if block is not None else self.block
748 750 targets = targets if targets is not None else self.targets
749 751 return self.client.kill(targets=targets, block=block)
750 752
751 753 #----------------------------------------
752 754 # activate for %px,%autopx magics
753 755 #----------------------------------------
754 756 def activate(self):
755 757 """Make this `View` active for parallel magic commands.
756 758
757 759 IPython has a magic command syntax to work with `MultiEngineClient` objects.
758 760 In a given IPython session there is a single active one. While
759 761 there can be many `Views` created and used by the user,
760 762 there is only one active one. The active `View` is used whenever
761 763 the magic commands %px and %autopx are used.
762 764
763 765 The activate() method is called on a given `View` to make it
764 766 active. Once this has been done, the magic commands can be used.
765 767 """
766 768
767 769 try:
768 770 # This is injected into __builtins__.
769 771 ip = get_ipython()
770 772 except NameError:
771 773 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
772 774 else:
773 775 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
774 776 if pmagic is None:
775 777 ip.magic_load_ext('parallelmagic')
776 778 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
777 779
778 780 pmagic.active_view = self
779 781
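A hedged sketch of the magics workflow; this is only meaningful inside an interactive IPython session with a running cluster:

    In [1]: from IPython.parallel import Client
    In [2]: v = Client()[:]
    In [3]: v.activate()       # make this View the target of %px/%autopx
    In [4]: %px b = 10         # executed on every engine in the active view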
780 782
781 783 @skip_doctest
782 784 class LoadBalancedView(View):
783 785 """An load-balancing View that only executes via the Task scheduler.
784 786
785 787 Load-balanced views can be created with the client's `view` method:
786 788
787 789 >>> v = client.load_balanced_view()
788 790
789 791 or targets can be specified, to restrict the potential destinations:
790 792
791 793 >>> v = client.load_balanced_view([1,3])
792 794
793 795 which would restrict load-balancing to engines 1 and 3.
794 796
795 797 """
796 798
797 799 follow=Any()
798 800 after=Any()
799 801 timeout=CFloat()
800 802 retries = Integer(0)
801 803
802 804 _task_scheme = Any()
803 805 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
804 806
805 807 def __init__(self, client=None, socket=None, **flags):
806 808 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
807 809 self._task_scheme=client._task_scheme
808 810
809 811 def _validate_dependency(self, dep):
810 812 """validate a dependency.
811 813
812 814 For use in `set_flags`.
813 815 """
814 816 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
815 817 return True
816 818 elif isinstance(dep, (list,set, tuple)):
817 819 for d in dep:
818 820 if not isinstance(d, (basestring, AsyncResult)):
819 821 return False
820 822 elif isinstance(dep, dict):
821 823 if set(dep.keys()) != set(Dependency().as_dict().keys()):
822 824 return False
823 825 if not isinstance(dep['msg_ids'], list):
824 826 return False
825 827 for d in dep['msg_ids']:
826 828 if not isinstance(d, basestring):
827 829 return False
828 830 else:
829 831 return False
830 832
831 833 return True
832 834
833 835 def _render_dependency(self, dep):
834 836 """helper for building jsonable dependencies from various input forms."""
835 837 if isinstance(dep, Dependency):
836 838 return dep.as_dict()
837 839 elif isinstance(dep, AsyncResult):
838 840 return dep.msg_ids
839 841 elif dep is None:
840 842 return []
841 843 else:
842 844 # pass to Dependency constructor
843 845 return list(Dependency(dep))
844 846
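A small sketch of supplying a dependency in one of the accepted forms (an AsyncResult here; msg_id lists and Dependency objects also validate), assuming a running cluster:

    from IPython.parallel import Client

    rc = Client()
    lv = rc.load_balanced_view()

    setup = lv.apply_async(lambda: 'setup done')
    with lv.temp_flags(after=setup):
        followup = lv.apply_async(lambda: 'runs only after setup completes')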
845 847 def set_flags(self, **kwargs):
846 848 """set my attribute flags by keyword.
847 849
848 850 A View is a wrapper for the Client's apply method, with attributes
849 851 that specify keyword arguments. Those attributes can be set by
850 852 keyword argument with this method.
851 853
852 854 Parameters
853 855 ----------
854 856
855 857 block : bool
856 858 whether to wait for results
857 859 track : bool
858 860 whether to create a MessageTracker to allow the user to
859 861 safely edit arrays and buffers after non-copying
860 862 sends.
861 863
862 864 after : Dependency or collection of msg_ids
863 865 Only for load-balanced execution (targets=None)
864 866 Specify a list of msg_ids as a time-based dependency.
865 867 This job will only be run *after* the dependencies
866 868 have been met.
867 869
868 870 follow : Dependency or collection of msg_ids
869 871 Only for load-balanced execution (targets=None)
870 872 Specify a list of msg_ids as a location-based dependency.
871 873 This job will only be run on an engine where this dependency
872 874 is met.
873 875
874 876 timeout : float/int or None
875 877 Only for load-balanced execution (targets=None)
876 878 Specify an amount of time (in seconds) for the scheduler to
877 879 wait for dependencies to be met before failing with a
878 880 DependencyTimeout.
879 881
880 882 retries : int
881 883 Number of times a task will be retried on failure.
882 884 """
883 885
884 886 super(LoadBalancedView, self).set_flags(**kwargs)
885 887 for name in ('follow', 'after'):
886 888 if name in kwargs:
887 889 value = kwargs[name]
888 890 if self._validate_dependency(value):
889 891 setattr(self, name, value)
890 892 else:
891 893 raise ValueError("Invalid dependency: %r"%value)
892 894 if 'timeout' in kwargs:
893 895 t = kwargs['timeout']
894 896 if not isinstance(t, (int, long, float, type(None))):
895 897 raise TypeError("Invalid type for timeout: %r"%type(t))
896 898 if t is not None:
897 899 if t < 0:
898 900 raise ValueError("Invalid timeout: %s"%t)
899 901 self.timeout = t
900 902
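A brief sketch of setting scheduler flags on a load-balanced view, under the same assumptions:

    from IPython.parallel import Client

    rc = Client()
    lv = rc.load_balanced_view()

    # give dependencies 10s before DependencyTimeout; retry failures twice
    lv.set_flags(timeout=10, retries=2)
    ar = lv.apply_async(lambda: 1 + 1)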
901 903 @sync_results
902 904 @save_ids
903 905 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
904 906 after=None, follow=None, timeout=None,
905 907 targets=None, retries=None):
906 908 """calls f(*args, **kwargs) on a remote engine, returning the result.
907 909
908 910 This method temporarily sets all of `apply`'s flags for a single call.
909 911
910 912 Parameters
911 913 ----------
912 914
913 915 f : callable
914 916
915 917 args : list [default: empty]
916 918
917 919 kwargs : dict [default: empty]
918 920
919 921 block : bool [default: self.block]
920 922 whether to block
921 923 track : bool [default: self.track]
922 924 whether to ask zmq to track the message, for safe non-copying sends
923 925
924 926 after, follow, timeout, targets, retries : see `set_flags` for details
925 927
926 928 Returns
927 929 -------
928 930
929 931 if self.block is False:
930 932 returns AsyncResult
931 933 else:
932 934 returns actual result of f(*args, **kwargs) on the engine(s)
933 935 This will be a list if self.targets is also a list (even length 1), or
934 936 the single result if self.targets is an integer engine id
935 937 """
936 938
937 939 # validate whether we can run
938 940 if self._socket.closed:
939 941 msg = "Task farming is disabled"
940 942 if self._task_scheme == 'pure':
941 943 msg += " because the pure ZMQ scheduler cannot handle"
942 944 msg += " disappearing engines."
943 945 raise RuntimeError(msg)
944 946
945 947 if self._task_scheme == 'pure':
946 948 # pure zmq scheme doesn't support extra features
947 949 msg = ("Pure ZMQ scheduler doesn't support the following flags: "
948 950 "follow, after, retries, targets, timeout")
949 951 if (follow or after or retries or targets or timeout):
950 952 # hard fail on Scheduler flags
951 953 raise RuntimeError(msg)
952 954 if isinstance(f, dependent):
953 955 # soft warn on functional dependencies
954 956 warnings.warn(msg, RuntimeWarning)
955 957
956 958 # build args
957 959 args = [] if args is None else args
958 960 kwargs = {} if kwargs is None else kwargs
959 961 block = self.block if block is None else block
960 962 track = self.track if track is None else track
961 963 after = self.after if after is None else after
962 964 retries = self.retries if retries is None else retries
963 965 follow = self.follow if follow is None else follow
964 966 timeout = self.timeout if timeout is None else timeout
965 967 targets = self.targets if targets is None else targets
966 968
967 969 if not isinstance(retries, int):
968 970 raise TypeError('retries must be int, not %r'%type(retries))
969 971
970 972 if targets is None:
971 973 idents = []
972 974 else:
973 975 idents = self.client._build_targets(targets)[0]
974 976 # ensure *not* bytes
975 977 idents = [ ident.decode() for ident in idents ]
976 978
977 979 after = self._render_dependency(after)
978 980 follow = self._render_dependency(follow)
979 981 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
980 982
981 983 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
982 984 subheader=subheader)
983 985 tracker = None if track is False else msg['tracker']
984 986
985 987 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
986 988
987 989 if block:
988 990 try:
989 991 return ar.get()
990 992 except KeyboardInterrupt:
991 993 pass
992 994 return ar
993 995
994 996 @spin_after
995 997 @save_ids
996 998 def map(self, f, *sequences, **kwargs):
997 999 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
998 1000
999 1001 Parallel version of builtin `map`, load-balanced by this View.
1000 1002
1001 1003 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1002 1004
1003 1005 Each `chunksize` elements will be a separate task, and will be
1004 1006 load-balanced. This lets individual elements be available for iteration
1005 1007 as soon as they arrive.
1006 1008
1007 1009 Parameters
1008 1010 ----------
1009 1011
1010 1012 f : callable
1011 1013 function to be mapped
1012 1014 *sequences: one or more sequences of matching length
1013 1015 the sequences to be distributed and passed to `f`
1014 1016 block : bool [default self.block]
1015 1017 whether to wait for the result or not
1016 1018 track : bool
1017 1019 whether to create a MessageTracker to allow the user to
1018 1020 safely edit arrays and buffers after non-copying
1019 1021 sends.
1020 1022 chunksize : int [default 1]
1021 1023 how many elements should be in each task.
1022 1024 ordered : bool [default True]
1023 1025 Whether the results should be gathered as they arrive, or enforce
1024 1026 the order of submission.
1025 1027
1026 1028 Only applies when iterating through AsyncMapResult as results arrive.
1027 1029 Has no effect when block=True.
1028 1030
1029 1031 Returns
1030 1032 -------
1031 1033
1032 1034 if block=False:
1033 1035 AsyncMapResult
1034 1036 An object like AsyncResult, but which reassembles the sequence of results
1035 1037 into a single list. AsyncMapResults can be iterated through before all
1036 1038 results are complete.
1037 1039 else:
1038 1040 the result of map(f,*sequences)
1039 1041
1040 1042 """
1041 1043
1042 1044 # default
1043 1045 block = kwargs.get('block', self.block)
1044 1046 chunksize = kwargs.get('chunksize', 1)
1045 1047 ordered = kwargs.get('ordered', True)
1046 1048
1047 1049 keyset = set(kwargs.keys())
1048 1050 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1049 1051 if extra_keys:
1050 1052 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1051 1053
1052 1054 assert len(sequences) > 0, "must have some sequences to map onto!"
1053 1055
1054 1056 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1055 1057 return pf.map(*sequences)
1056 1058
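A minimal sketch of a chunked, unordered load-balanced map, under the same assumptions:

    from IPython.parallel import Client

    rc = Client()
    lv = rc.load_balanced_view()

    amr = lv.map(lambda x: x**2, range(100), block=False,
                 chunksize=10, ordered=False)
    for value in amr:      # results yielded as each chunk finishes
        print(value)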
1057 1059 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,463 +1,472 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28
29 29 from IPython import parallel as pmod
30 30 from IPython.parallel import error
31 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 32 from IPython.parallel import DirectView
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 38
39 39 def setup():
40 40 add_engines(3)
41 41
42 42 class TestView(ClusterTestCase):
43 43
44 44 def test_z_crash_mux(self):
45 45 """test graceful handling of engine death (direct)"""
46 46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEquals(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEquals(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEquals(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEquals(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEquals(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEquals(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEquals(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 145 ahr = v2.get_result(ar.msg_ids)
146 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 147 self.assertEquals(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids)
149 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 150 c.spin()
151 151 c.close()
152 152
153 153 def test_run_newline(self):
154 154 """test that run appends newline to files"""
155 155 tmpfile = mktemp()
156 156 with open(tmpfile, 'w') as f:
157 157 f.write("""def g():
158 158 return 5
159 159 """)
160 160 v = self.client[-1]
161 161 v.run(tmpfile, block=True)
162 162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 163
164 164 def test_apply_tracked(self):
165 165 """test tracking for apply"""
166 166 # self.add_engines(1)
167 167 t = self.client.ids[-1]
168 168 v = self.client[t]
169 169 v.block=False
170 170 def echo(n=1024*1024, **kwargs):
171 171 with v.temp_flags(**kwargs):
172 172 return v.apply(lambda x: x, 'x'*n)
173 173 ar = echo(1, track=False)
174 174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 175 self.assertTrue(ar.sent)
176 176 ar = echo(track=True)
177 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 178 self.assertEquals(ar.sent, ar._tracker.done)
179 179 ar._tracker.wait()
180 180 self.assertTrue(ar.sent)
181 181
182 182 def test_push_tracked(self):
183 183 t = self.client.ids[-1]
184 184 ns = dict(x='x'*1024*1024)
185 185 v = self.client[t]
186 186 ar = v.push(ns, block=False, track=False)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertTrue(ar.sent)
189 189
190 190 ar = v.push(ns, block=False, track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 ar._tracker.wait()
193 193 self.assertEquals(ar.sent, ar._tracker.done)
194 194 self.assertTrue(ar.sent)
195 195 ar.get()
196 196
197 197 def test_scatter_tracked(self):
198 198 t = self.client.ids
199 199 x='x'*1024*1024
200 200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 self.assertEquals(ar.sent, ar._tracker.done)
207 207 ar._tracker.wait()
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_remote_reference(self):
212 212 v = self.client[-1]
213 213 v['a'] = 123
214 214 ra = pmod.Reference('a')
215 215 b = v.apply_sync(lambda x: x, ra)
216 216 self.assertEquals(b, 123)
217 217
218 218
219 219 def test_scatter_gather(self):
220 220 view = self.client[:]
221 221 seq1 = range(16)
222 222 view.scatter('a', seq1)
223 223 seq2 = view.gather('a', block=True)
224 224 self.assertEquals(seq2, seq1)
225 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 226
227 227 @skip_without('numpy')
228 228 def test_scatter_gather_numpy(self):
229 229 import numpy
230 230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 231 view = self.client[:]
232 232 a = numpy.arange(64)
233 233 view.scatter('a', a)
234 234 b = view.gather('a', block=True)
235 235 assert_array_equal(b, a)
236 236
237 237 def test_map(self):
238 238 view = self.client[:]
239 239 def f(x):
240 240 return x**2
241 241 data = range(16)
242 242 r = view.map_sync(f, data)
243 243 self.assertEquals(r, map(f, data))
244 244
245 245 def test_map_iterable(self):
246 246 """test map on iterables (direct)"""
247 247 view = self.client[:]
248 248 # 101 is prime, so it won't be evenly distributed
249 249 arr = range(101)
250 250 # ensure it will be an iterator, even in Python 3
251 251 it = iter(arr)
252 252 r = view.map_sync(lambda x:x, arr)
253 253 self.assertEquals(r, list(arr))
254 254
255 255 def test_scatterGatherNonblocking(self):
256 256 data = range(16)
257 257 view = self.client[:]
258 258 view.scatter('a', data, block=False)
259 259 ar = view.gather('a', block=False)
260 260 self.assertEquals(ar.get(), data)
261 261
262 262 @skip_without('numpy')
263 263 def test_scatter_gather_numpy_nonblocking(self):
264 264 import numpy
265 265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 266 a = numpy.arange(64)
267 267 view = self.client[:]
268 268 ar = view.scatter('a', a, block=False)
269 269 self.assertTrue(isinstance(ar, AsyncResult))
270 270 amr = view.gather('a', block=False)
271 271 self.assertTrue(isinstance(amr, AsyncMapResult))
272 272 assert_array_equal(amr.get(), a)
273 273
274 274 def test_execute(self):
275 275 view = self.client[:]
276 276 # self.client.debug=True
277 277 execute = view.execute
278 278 ar = execute('c=30', block=False)
279 279 self.assertTrue(isinstance(ar, AsyncResult))
280 280 ar = execute('d=[0,1,2]', block=False)
281 281 self.client.wait(ar, 1)
282 282 self.assertEquals(len(ar.get()), len(self.client))
283 283 for c in view['c']:
284 284 self.assertEquals(c, 30)
285 285
286 286 def test_abort(self):
287 287 view = self.client[-1]
288 288 ar = view.execute('import time; time.sleep(1)', block=False)
289 289 ar2 = view.apply_async(lambda : 2)
290 290 ar3 = view.apply_async(lambda : 3)
291 291 view.abort(ar2)
292 292 view.abort(ar3.msg_ids)
293 293 self.assertRaises(error.TaskAborted, ar2.get)
294 294 self.assertRaises(error.TaskAborted, ar3.get)
295
295
296 def test_abort_all(self):
297 """view.abort() aborts all outstanding tasks"""
298 view = self.client[-1]
299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 view.abort()
301 view.wait(timeout=5)
302 for ar in ars[5:]:
303 self.assertRaises(error.TaskAborted, ar.get)
304
296 305 def test_temp_flags(self):
297 306 view = self.client[-1]
298 307 view.block=True
299 308 with view.temp_flags(block=False):
300 309 self.assertFalse(view.block)
301 310 self.assertTrue(view.block)
302 311
303 312 @dec.known_failure_py3
304 313 def test_importer(self):
305 314 view = self.client[-1]
306 315 view.clear(block=True)
307 316 with view.importer:
308 317 import re
309 318
310 319 @interactive
311 320 def findall(pat, s):
312 321 # this globals() step isn't necessary in real code
313 322 # only to prevent a closure in the test
314 323 re = globals()['re']
315 324 return re.findall(pat, s)
316 325
317 326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
318 327
319 328 # parallel magic tests
320 329
321 330 def test_magic_px_blocking(self):
322 331 ip = get_ipython()
323 332 v = self.client[-1]
324 333 v.activate()
325 334 v.block=True
326 335
327 336 ip.magic_px('a=5')
328 337 self.assertEquals(v['a'], 5)
329 338 ip.magic_px('a=10')
330 339 self.assertEquals(v['a'], 10)
331 340 sio = StringIO()
332 341 savestdout = sys.stdout
333 342 sys.stdout = sio
334 343 # just 'print a' works ~99% of the time, but this ensures that
335 344 # the stdout message has arrived when the result is finished:
336 345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
337 346 sys.stdout = savestdout
338 347 buf = sio.getvalue()
339 348 self.assertTrue('[stdout:' in buf, buf)
340 349 self.assertTrue(buf.rstrip().endswith('10'))
341 350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
342 351
343 352 def test_magic_px_nonblocking(self):
344 353 ip = get_ipython()
345 354 v = self.client[-1]
346 355 v.activate()
347 356 v.block=False
348 357
349 358 ip.magic_px('a=5')
350 359 self.assertEquals(v['a'], 5)
351 360 ip.magic_px('a=10')
352 361 self.assertEquals(v['a'], 10)
353 362 sio = StringIO()
354 363 savestdout = sys.stdout
355 364 sys.stdout = sio
356 365 ip.magic_px('print a')
357 366 sys.stdout = savestdout
358 367 buf = sio.getvalue()
359 368 self.assertFalse('[stdout:%i]'%v.targets in buf)
360 369 ip.magic_px('1/0')
361 370 ar = v.get_result(-1)
362 371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
363 372
364 373 def test_magic_autopx_blocking(self):
365 374 ip = get_ipython()
366 375 v = self.client[-1]
367 376 v.activate()
368 377 v.block=True
369 378
370 379 sio = StringIO()
371 380 savestdout = sys.stdout
372 381 sys.stdout = sio
373 382 ip.magic_autopx()
374 383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
375 384 ip.run_cell('print b')
376 385 ip.run_cell("b/c")
377 386 ip.run_code(compile('b*=2', '', 'single'))
378 387 ip.magic_autopx()
379 388 sys.stdout = savestdout
380 389 output = sio.getvalue().strip()
381 390 self.assertTrue(output.startswith('%autopx enabled'))
382 391 self.assertTrue(output.endswith('%autopx disabled'))
383 392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
384 393 ar = v.get_result(-2)
385 394 self.assertEquals(v['a'], 5)
386 395 self.assertEquals(v['b'], 20)
387 396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
388 397
389 398 def test_magic_autopx_nonblocking(self):
390 399 ip = get_ipython()
391 400 v = self.client[-1]
392 401 v.activate()
393 402 v.block=False
394 403
395 404 sio = StringIO()
396 405 savestdout = sys.stdout
397 406 sys.stdout = sio
398 407 ip.magic_autopx()
399 408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
400 409 ip.run_cell('print b')
401 410 ip.run_cell("b/c")
402 411 ip.run_code(compile('b*=2', '', 'single'))
403 412 ip.magic_autopx()
404 413 sys.stdout = savestdout
405 414 output = sio.getvalue().strip()
406 415 self.assertTrue(output.startswith('%autopx enabled'))
407 416 self.assertTrue(output.endswith('%autopx disabled'))
408 417 self.assertFalse('ZeroDivisionError' in output)
409 418 ar = v.get_result(-2)
410 419 self.assertEquals(v['a'], 5)
411 420 self.assertEquals(v['b'], 20)
412 421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
413 422
414 423 def test_magic_result(self):
415 424 ip = get_ipython()
416 425 v = self.client[-1]
417 426 v.activate()
418 427 v['a'] = 111
419 428 ra = v['a']
420 429
421 430 ar = ip.magic_result()
422 431 self.assertEquals(ar.msg_ids, [v.history[-1]])
423 432 self.assertEquals(ar.get(), 111)
424 433 ar = ip.magic_result('-2')
425 434 self.assertEquals(ar.msg_ids, [v.history[-2]])
426 435
427 436 def test_unicode_execute(self):
428 437 """test executing unicode strings"""
429 438 v = self.client[-1]
430 439 v.block=True
431 440 if sys.version_info[0] >= 3:
432 441 code="a='é'"
433 442 else:
434 443 code=u"a=u'é'"
435 444 v.execute(code)
436 445 self.assertEquals(v['a'], u'é')
437 446
438 447 def test_unicode_apply_result(self):
439 448 """test unicode apply results"""
440 449 v = self.client[-1]
441 450 r = v.apply_sync(lambda : u'é')
442 451 self.assertEquals(r, u'é')
443 452
444 453 def test_unicode_apply_arg(self):
445 454 """test passing unicode arguments to apply"""
446 455 v = self.client[-1]
447 456
448 457 @interactive
449 458 def check_unicode(a, check):
450 459 assert isinstance(a, unicode), "%r is not unicode"%a
451 460 assert isinstance(check, bytes), "%r is not bytes"%check
452 461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
453 462
454 463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
455 464 try:
456 465 v.apply_sync(check_unicode, s, s.encode('utf8'))
457 466 except error.RemoteError as e:
458 467 if e.ename == 'AssertionError':
459 468 self.fail(e.evalue)
460 469 else:
461 470 raise e
462 471
463 472