allow Reference as callable in map/apply...
MinRK
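This change relaxes the callable check in `Client.send_apply_message` so that a `Reference` (a placeholder for an object living in an engine's namespace) is accepted wherever `map`/`apply` expect a callable, and adds a `getname` helper to `remotefunction.py` for labeling such objects. A minimal sketch of the new usage, assuming a running cluster and that `double` has been defined on the engines first:

    from IPython.parallel import Client, Reference

    rc = Client()
    view = rc[:]
    view.execute("def double(x): return 2 * x")   # define double in each engine's namespace
    ar = view.apply(Reference('double'), 21)      # the Reference is resolved remotely, then called
    print(ar.get())                               # one result per engine, e.g. [42, 42]
    print(view.map(Reference('double'), range(4)).get())  # [0, 2, 4, 6]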
@@ -1,1443 +1,1444 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 from IPython.parallel import Reference
43 44 from IPython.parallel import error
44 45 from IPython.parallel import util
45 46
46 47 from IPython.zmq.session import Session, Message
47 48
48 49 from .asyncresult import AsyncResult, AsyncHubResult
49 50 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 51 from .view import DirectView, LoadBalancedView
51 52
52 53 if sys.version_info[0] >= 3:
53 54 # xrange is used in a couple 'isinstance' tests in py2
54 55 # should be just 'range' in 3k
55 56 xrange = range
56 57
57 58 #--------------------------------------------------------------------------
58 59 # Decorators for Client methods
59 60 #--------------------------------------------------------------------------
60 61
61 62 @decorator
62 63 def spin_first(f, self, *args, **kwargs):
63 64 """Call spin() to sync state prior to calling the method."""
64 65 self.spin()
65 66 return f(self, *args, **kwargs)
66 67
67 68
68 69 #--------------------------------------------------------------------------
69 70 # Classes
70 71 #--------------------------------------------------------------------------
71 72
72 73 class Metadata(dict):
73 74 """Subclass of dict for initializing metadata values.
74 75
75 76 Attribute access works on keys.
76 77
77 78 These objects have a strict set of keys - an error is raised if you try
78 79 to add new keys.
79 80 """
80 81 def __init__(self, *args, **kwargs):
81 82 dict.__init__(self)
82 83 md = {'msg_id' : None,
83 84 'submitted' : None,
84 85 'started' : None,
85 86 'completed' : None,
86 87 'received' : None,
87 88 'engine_uuid' : None,
88 89 'engine_id' : None,
89 90 'follow' : None,
90 91 'after' : None,
91 92 'status' : None,
92 93
93 94 'pyin' : None,
94 95 'pyout' : None,
95 96 'pyerr' : None,
96 97 'stdout' : '',
97 98 'stderr' : '',
98 99 }
99 100 self.update(md)
100 101 self.update(dict(*args, **kwargs))
101 102
102 103 def __getattr__(self, key):
103 104 """getattr aliased to getitem"""
104 105 if key in self.iterkeys():
105 106 return self[key]
106 107 else:
107 108 raise AttributeError(key)
108 109
109 110 def __setattr__(self, key, value):
110 111 """setattr aliased to setitem, with strict"""
111 112 if key in self.iterkeys():
112 113 self[key] = value
113 114 else:
114 115 raise AttributeError(key)
115 116
116 117 def __setitem__(self, key, value):
117 118 """strict static key enforcement"""
118 119 if key in self.iterkeys():
119 120 dict.__setitem__(self, key, value)
120 121 else:
121 122 raise KeyError(key)
122 123
123 124
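A quick sketch of the strict-key behavior described in the `Metadata` docstring above (attribute access over a fixed key set):

    md = Metadata()
    md.status = 'ok'          # attribute access maps onto the fixed keys
    md['stdout'] += 'hello'   # known keys behave like a plain dict
    md.bogus = 1              # raises AttributeError: not in the key set
    md['bogus'] = 1           # raises KeyError: __setitem__ is strict too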
124 125 class Client(HasTraits):
125 126 """A semi-synchronous client to the IPython ZMQ cluster
126 127
127 128 Parameters
128 129 ----------
129 130
130 131 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 132 Connection information for the Hub's registration. If a json connector
132 133 file is given, then likely no further configuration is necessary.
133 134 [Default: use profile]
134 135 profile : bytes
135 136 The name of the Cluster profile to be used to find connector information.
136 137 If run from an IPython application, the default profile will be the same
137 138 as the running application, otherwise it will be 'default'.
138 139 context : zmq.Context
139 140 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 141 debug : bool
141 142 flag for lots of message printing for debug purposes
142 143 timeout : int/float
143 144 time (in seconds) to wait for connection replies from the Hub
144 145 [Default: 10]
145 146
146 147 #-------------- session related args ----------------
147 148
148 149 config : Config object
149 150 If specified, this will be relayed to the Session for configuration
150 151 username : str
151 152 set username for the session object
152 153 packer : str (import_string) or callable
153 154 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 155 function to serialize messages. Must support the same input as
155 156 JSON, and output must be bytes.
156 157 You can pass a callable directly as `pack`
157 158 unpacker : str (import_string) or callable
158 159 The inverse of packer. Only necessary if packer is specified as *not* one
159 160 of 'json' or 'pickle'.
160 161
161 162 #-------------- ssh related args ----------------
162 163 # These are args for configuring the ssh tunnel to be used
163 164 # credentials are used to forward connections over ssh to the Controller
164 165 # Note that the ip given in `addr` needs to be relative to sshserver
165 166 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 167 # and set sshserver as the same machine the Controller is on. However,
167 168 # the only requirement is that sshserver is able to see the Controller
168 169 # (i.e. is within the same trusted network).
169 170
170 171 sshserver : str
171 172 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 173 If keyfile or password is specified, and this is not, it will default to
173 174 the ip given in addr.
174 175 sshkey : str; path to ssh private key file
175 176 This specifies a key to be used in ssh login, default None.
176 177 Regular default ssh keys will be used without specifying this argument.
177 178 password : str
178 179 Your ssh password to sshserver. Note that if this is left None,
179 180 you will be prompted for it if passwordless key based login is unavailable.
180 181 paramiko : bool
181 182 flag for whether to use paramiko instead of shell ssh for tunneling.
182 183 [default: True on win32, False otherwise]
183 184
184 185 ------- exec authentication args -------
185 186 If even localhost is untrusted, you can have some protection against
186 187 unauthorized execution by signing messages with HMAC digests.
187 188 Messages are still sent as cleartext, so if someone can snoop your
188 189 loopback traffic this will not protect your privacy, but will prevent
189 190 unauthorized execution.
190 191
191 192 exec_key : str
192 193 an authentication key or file containing a key
193 194 default: None
194 195
195 196
196 197 Attributes
197 198 ----------
198 199
199 200 ids : list of int engine IDs
200 201 requesting the ids attribute always synchronizes
201 202 the registration state. To request ids without synchronization,
202 203 use the semi-private _ids attribute.
203 204
204 205 history : list of msg_ids
205 206 a list of msg_ids, keeping track of all the execution
206 207 messages you have submitted in order.
207 208
208 209 outstanding : set of msg_ids
209 210 a set of msg_ids that have been submitted, but whose
210 211 results have not yet been received.
211 212
212 213 results : dict
213 214 a dict of all our results, keyed by msg_id
214 215
215 216 block : bool
216 217 determines default behavior when block not specified
217 218 in execution methods
218 219
219 220 Methods
220 221 -------
221 222
222 223 spin
223 224 flushes incoming results and registration state changes
224 225 control methods spin as well, and requesting `ids` also ensures the state is up to date
225 226
226 227 wait
227 228 wait on one or more msg_ids
228 229
229 230 execution methods
230 231 apply
231 232 legacy: execute, run
232 233
233 234 data movement
234 235 push, pull, scatter, gather
235 236
236 237 query methods
237 238 queue_status, get_result, purge, result_status
238 239
239 240 control methods
240 241 abort, shutdown
241 242
242 243 """
243 244
244 245
245 246 block = Bool(False)
246 247 outstanding = Set()
247 248 results = Instance('collections.defaultdict', (dict,))
248 249 metadata = Instance('collections.defaultdict', (Metadata,))
249 250 history = List()
250 251 debug = Bool(False)
251 252
252 253 profile=Unicode()
253 254 def _profile_default(self):
254 255 if BaseIPythonApplication.initialized():
255 256 # an IPython app *might* be running, try to get its profile
256 257 try:
257 258 return BaseIPythonApplication.instance().profile
258 259 except (AttributeError, MultipleInstanceError):
259 260 # could be a *different* subclass of config.Application,
260 261 # which would raise one of these two errors.
261 262 return u'default'
262 263 else:
263 264 return u'default'
264 265
265 266
266 267 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 268 _ids = List()
268 269 _connected=Bool(False)
269 270 _ssh=Bool(False)
270 271 _context = Instance('zmq.Context')
271 272 _config = Dict()
272 273 _engines=Instance(util.ReverseDict, (), {})
273 274 # _hub_socket=Instance('zmq.Socket')
274 275 _query_socket=Instance('zmq.Socket')
275 276 _control_socket=Instance('zmq.Socket')
276 277 _iopub_socket=Instance('zmq.Socket')
277 278 _notification_socket=Instance('zmq.Socket')
278 279 _mux_socket=Instance('zmq.Socket')
279 280 _task_socket=Instance('zmq.Socket')
280 281 _task_scheme=Unicode()
281 282 _closed = False
282 283 _ignored_control_replies=Integer(0)
283 284 _ignored_hub_replies=Integer(0)
284 285
285 286 def __new__(self, *args, **kw):
286 287 # don't raise on positional args
287 288 return HasTraits.__new__(self, **kw)
288 289
289 290 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 291 context=None, debug=False, exec_key=None,
291 292 sshserver=None, sshkey=None, password=None, paramiko=None,
292 293 timeout=10, **extra_args
293 294 ):
294 295 if profile:
295 296 super(Client, self).__init__(debug=debug, profile=profile)
296 297 else:
297 298 super(Client, self).__init__(debug=debug)
298 299 if context is None:
299 300 context = zmq.Context.instance()
300 301 self._context = context
301 302
302 303 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 304 if self._cd is not None:
304 305 if url_or_file is None:
305 306 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 307 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 308 " Please specify at least one of url_or_file or profile."
308 309
309 310 if not util.is_url(url_or_file):
310 311 # it's not a url, try for a file
311 312 if not os.path.exists(url_or_file):
312 313 if self._cd:
313 314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
314 315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
315 316 with open(url_or_file) as f:
316 317 cfg = json.loads(f.read())
317 318 else:
318 319 cfg = {'url':url_or_file}
319 320
320 321 # sync defaults from args, json:
321 322 if sshserver:
322 323 cfg['ssh'] = sshserver
323 324 if exec_key:
324 325 cfg['exec_key'] = exec_key
325 326 exec_key = cfg['exec_key']
326 327 location = cfg.setdefault('location', None)
327 328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
328 329 url = cfg['url']
329 330 proto,addr,port = util.split_url(url)
330 331 if location is not None and addr == '127.0.0.1':
331 332 # location specified, and connection is expected to be local
332 333 if location not in LOCAL_IPS and not sshserver:
333 334 # load ssh from JSON *only* if the controller is not on
334 335 # this machine
335 336 sshserver=cfg['ssh']
336 337 if location not in LOCAL_IPS and not sshserver:
337 338 # warn if no ssh specified, but SSH is probably needed
338 339 # This is only a warning, because the most likely cause
339 340 # is a local Controller on a laptop whose IP is dynamic
340 341 warnings.warn("""
341 342 Controller appears to be listening on localhost, but not on this machine.
342 343 If this is true, you should specify Client(...,sshserver='you@%s')
343 344 or instruct your controller to listen on an external IP."""%location,
344 345 RuntimeWarning)
345 346 elif not sshserver:
346 347 # otherwise sync with cfg
347 348 sshserver = cfg['ssh']
348 349
349 350 self._config = cfg
350 351
351 352 self._ssh = bool(sshserver or sshkey or password)
352 353 if self._ssh and sshserver is None:
353 354 # default to ssh via localhost
354 355 sshserver = url.split('://')[1].split(':')[0]
355 356 if self._ssh and password is None:
356 357 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
357 358 password=False
358 359 else:
359 360 password = getpass("SSH Password for %s: "%sshserver)
360 361 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
361 362
362 363 # configure and construct the session
363 364 if exec_key is not None:
364 365 if os.path.isfile(exec_key):
365 366 extra_args['keyfile'] = exec_key
366 367 else:
367 368 exec_key = util.asbytes(exec_key)
368 369 extra_args['key'] = exec_key
369 370 self.session = Session(**extra_args)
370 371
371 372 self._query_socket = self._context.socket(zmq.DEALER)
372 373 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
373 374 if self._ssh:
374 375 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
375 376 else:
376 377 self._query_socket.connect(url)
377 378
378 379 self.session.debug = self.debug
379 380
380 381 self._notification_handlers = {'registration_notification' : self._register_engine,
381 382 'unregistration_notification' : self._unregister_engine,
382 383 'shutdown_notification' : lambda msg: self.close(),
383 384 }
384 385 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
385 386 'apply_reply' : self._handle_apply_reply}
386 387 self._connect(sshserver, ssh_kwargs, timeout)
387 388
388 389 def __del__(self):
389 390 """cleanup sockets, but _not_ context."""
390 391 self.close()
391 392
392 393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
393 394 if ipython_dir is None:
394 395 ipython_dir = get_ipython_dir()
395 396 if profile_dir is not None:
396 397 try:
397 398 self._cd = ProfileDir.find_profile_dir(profile_dir)
398 399 return
399 400 except ProfileDirError:
400 401 pass
401 402 elif profile is not None:
402 403 try:
403 404 self._cd = ProfileDir.find_profile_dir_by_name(
404 405 ipython_dir, profile)
405 406 return
406 407 except ProfileDirError:
407 408 pass
408 409 self._cd = None
409 410
410 411 def _update_engines(self, engines):
411 412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
412 413 for k,v in engines.iteritems():
413 414 eid = int(k)
414 415 self._engines[eid] = v
415 416 self._ids.append(eid)
416 417 self._ids = sorted(self._ids)
417 418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
418 419 self._task_scheme == 'pure' and self._task_socket:
419 420 self._stop_scheduling_tasks()
420 421
421 422 def _stop_scheduling_tasks(self):
422 423 """Stop scheduling tasks because an engine has been unregistered
423 424 from a pure ZMQ scheduler.
424 425 """
425 426 self._task_socket.close()
426 427 self._task_socket = None
427 428 msg = "An engine has been unregistered, and we are using pure " +\
428 429 "ZMQ task scheduling. Task farming will be disabled."
429 430 if self.outstanding:
430 431 msg += " If you were running tasks when this happened, " +\
431 432 "some `outstanding` msg_ids may never resolve."
432 433 warnings.warn(msg, RuntimeWarning)
433 434
434 435 def _build_targets(self, targets):
435 436 """Turn valid target IDs or 'all' into two lists:
436 437 (int_ids, uuids).
437 438 """
438 439 if not self._ids:
439 440 # flush notification socket if no engines yet, just in case
440 441 if not self.ids:
441 442 raise error.NoEnginesRegistered("Can't build targets without any engines")
442 443
443 444 if targets is None:
444 445 targets = self._ids
445 446 elif isinstance(targets, basestring):
446 447 if targets.lower() == 'all':
447 448 targets = self._ids
448 449 else:
449 450 raise TypeError("%r not valid str target, must be 'all'"%(targets))
450 451 elif isinstance(targets, int):
451 452 if targets < 0:
452 453 targets = self.ids[targets]
453 454 if targets not in self._ids:
454 455 raise IndexError("No such engine: %i"%targets)
455 456 targets = [targets]
456 457
457 458 if isinstance(targets, slice):
458 459 indices = range(len(self._ids))[targets]
459 460 ids = self.ids
460 461 targets = [ ids[i] for i in indices ]
461 462
462 463 if not isinstance(targets, (tuple, list, xrange)):
463 464 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
464 465
465 466 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
466 467
467 468 def _connect(self, sshserver, ssh_kwargs, timeout):
468 469 """setup all our socket connections to the cluster. This is called from
469 470 __init__."""
470 471
471 472 # Maybe allow reconnecting?
472 473 if self._connected:
473 474 return
474 475 self._connected=True
475 476
476 477 def connect_socket(s, url):
477 478 url = util.disambiguate_url(url, self._config['location'])
478 479 if self._ssh:
479 480 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
480 481 else:
481 482 return s.connect(url)
482 483
483 484 self.session.send(self._query_socket, 'connection_request')
484 485 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
485 486 poller = zmq.Poller()
486 487 poller.register(self._query_socket, zmq.POLLIN)
487 488 # poll expects milliseconds, timeout is seconds
488 489 evts = poller.poll(timeout*1000)
489 490 if not evts:
490 491 raise error.TimeoutError("Hub connection request timed out")
491 492 idents,msg = self.session.recv(self._query_socket,mode=0)
492 493 if self.debug:
493 494 pprint(msg)
494 495 msg = Message(msg)
495 496 content = msg.content
496 497 self._config['registration'] = dict(content)
497 498 if content.status == 'ok':
498 499 ident = self.session.bsession
499 500 if content.mux:
500 501 self._mux_socket = self._context.socket(zmq.DEALER)
501 502 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
502 503 connect_socket(self._mux_socket, content.mux)
503 504 if content.task:
504 505 self._task_scheme, task_addr = content.task
505 506 self._task_socket = self._context.socket(zmq.DEALER)
506 507 self._task_socket.setsockopt(zmq.IDENTITY, ident)
507 508 connect_socket(self._task_socket, task_addr)
508 509 if content.notification:
509 510 self._notification_socket = self._context.socket(zmq.SUB)
510 511 connect_socket(self._notification_socket, content.notification)
511 512 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
512 513 # if content.query:
513 514 # self._query_socket = self._context.socket(zmq.DEALER)
514 515 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
515 516 # connect_socket(self._query_socket, content.query)
516 517 if content.control:
517 518 self._control_socket = self._context.socket(zmq.DEALER)
518 519 self._control_socket.setsockopt(zmq.IDENTITY, ident)
519 520 connect_socket(self._control_socket, content.control)
520 521 if content.iopub:
521 522 self._iopub_socket = self._context.socket(zmq.SUB)
522 523 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
523 524 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
524 525 connect_socket(self._iopub_socket, content.iopub)
525 526 self._update_engines(dict(content.engines))
526 527 else:
527 528 self._connected = False
528 529 raise Exception("Failed to connect!")
529 530
530 531 #--------------------------------------------------------------------------
531 532 # handlers and callbacks for incoming messages
532 533 #--------------------------------------------------------------------------
533 534
534 535 def _unwrap_exception(self, content):
535 536 """unwrap exception, and remap engine_id to int."""
536 537 e = error.unwrap_exception(content)
537 538 # print e.traceback
538 539 if e.engine_info:
539 540 e_uuid = e.engine_info['engine_uuid']
540 541 eid = self._engines[e_uuid]
541 542 e.engine_info['engine_id'] = eid
542 543 return e
543 544
544 545 def _extract_metadata(self, header, parent, content):
545 546 md = {'msg_id' : parent['msg_id'],
546 547 'received' : datetime.now(),
547 548 'engine_uuid' : header.get('engine', None),
548 549 'follow' : parent.get('follow', []),
549 550 'after' : parent.get('after', []),
550 551 'status' : content['status'],
551 552 }
552 553
553 554 if md['engine_uuid'] is not None:
554 555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
555 556
556 557 if 'date' in parent:
557 558 md['submitted'] = parent['date']
558 559 if 'started' in header:
559 560 md['started'] = header['started']
560 561 if 'date' in header:
561 562 md['completed'] = header['date']
562 563 return md
563 564
564 565 def _register_engine(self, msg):
565 566 """Register a new engine, and update our connection info."""
566 567 content = msg['content']
567 568 eid = content['id']
568 569 d = {eid : content['queue']}
569 570 self._update_engines(d)
570 571
571 572 def _unregister_engine(self, msg):
572 573 """Unregister an engine that has died."""
573 574 content = msg['content']
574 575 eid = int(content['id'])
575 576 if eid in self._ids:
576 577 self._ids.remove(eid)
577 578 uuid = self._engines.pop(eid)
578 579
579 580 self._handle_stranded_msgs(eid, uuid)
580 581
581 582 if self._task_socket and self._task_scheme == 'pure':
582 583 self._stop_scheduling_tasks()
583 584
584 585 def _handle_stranded_msgs(self, eid, uuid):
585 586 """Handle messages known to be on an engine when the engine unregisters.
586 587
587 588 It is possible that this will fire prematurely - that is, an engine will
588 589 go down after completing a result, and the client will be notified
589 590 of the unregistration and later receive the successful result.
590 591 """
591 592
592 593 outstanding = self._outstanding_dict[uuid]
593 594
594 595 for msg_id in list(outstanding):
595 596 if msg_id in self.results:
596 597 # we already have this result
597 598 continue
598 599 try:
599 600 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
600 601 except:
601 602 content = error.wrap_exception()
602 603 # build a fake message:
603 604 parent = {}
604 605 header = {}
605 606 parent['msg_id'] = msg_id
606 607 header['engine'] = uuid
607 608 header['date'] = datetime.now()
608 609 msg = dict(parent_header=parent, header=header, content=content)
609 610 self._handle_apply_reply(msg)
610 611
611 612 def _handle_execute_reply(self, msg):
612 613 """Save the reply to an execute_request into our results.
613 614
614 615 execute messages are never actually used. apply is used instead.
615 616 """
616 617
617 618 parent = msg['parent_header']
618 619 msg_id = parent['msg_id']
619 620 if msg_id not in self.outstanding:
620 621 if msg_id in self.history:
621 622 print ("got stale result: %s"%msg_id)
622 623 else:
623 624 print ("got unknown result: %s"%msg_id)
624 625 else:
625 626 self.outstanding.remove(msg_id)
626 627 self.results[msg_id] = self._unwrap_exception(msg['content'])
627 628
628 629 def _handle_apply_reply(self, msg):
629 630 """Save the reply to an apply_request into our results."""
630 631 parent = msg['parent_header']
631 632 msg_id = parent['msg_id']
632 633 if msg_id not in self.outstanding:
633 634 if msg_id in self.history:
634 635 print ("got stale result: %s"%msg_id)
635 636 print self.results[msg_id]
636 637 print msg
637 638 else:
638 639 print ("got unknown result: %s"%msg_id)
639 640 else:
640 641 self.outstanding.remove(msg_id)
641 642 content = msg['content']
642 643 header = msg['header']
643 644
644 645 # construct metadata:
645 646 md = self.metadata[msg_id]
646 647 md.update(self._extract_metadata(header, parent, content))
647 648 # is this redundant?
648 649 self.metadata[msg_id] = md
649 650
650 651 e_outstanding = self._outstanding_dict[md['engine_uuid']]
651 652 if msg_id in e_outstanding:
652 653 e_outstanding.remove(msg_id)
653 654
654 655 # construct result:
655 656 if content['status'] == 'ok':
656 657 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
657 658 elif content['status'] == 'aborted':
658 659 self.results[msg_id] = error.TaskAborted(msg_id)
659 660 elif content['status'] == 'resubmitted':
660 661 # TODO: handle resubmission
661 662 pass
662 663 else:
663 664 self.results[msg_id] = self._unwrap_exception(content)
664 665
665 666 def _flush_notifications(self):
666 667 """Flush notifications of engine registrations waiting
667 668 in the ZMQ queue.
668 669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
669 670 while msg is not None:
670 671 if self.debug:
671 672 pprint(msg)
672 673 msg_type = msg['header']['msg_type']
673 674 handler = self._notification_handlers.get(msg_type, None)
674 675 if handler is None:
675 676 raise Exception("Unhandled message type: %s"%msg_type)
676 677 else:
677 678 handler(msg)
678 679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
679 680
680 681 def _flush_results(self, sock):
681 682 """Flush task or queue results waiting in ZMQ queue."""
682 683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
683 684 while msg is not None:
684 685 if self.debug:
685 686 pprint(msg)
686 687 msg_type = msg['header']['msg_type']
687 688 handler = self._queue_handlers.get(msg_type, None)
688 689 if handler is None:
689 690 raise Exception("Unhandled message type: %s"%msg_type)
690 691 else:
691 692 handler(msg)
692 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
693 694
694 695 def _flush_control(self, sock):
695 696 """Flush replies from the control channel waiting
696 697 in the ZMQ queue.
697 698
698 699 Currently: ignore them."""
699 700 if self._ignored_control_replies <= 0:
700 701 return
701 702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
702 703 while msg is not None:
703 704 self._ignored_control_replies -= 1
704 705 if self.debug:
705 706 pprint(msg)
706 707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 708
708 709 def _flush_ignored_control(self):
709 710 """flush ignored control replies"""
710 711 while self._ignored_control_replies > 0:
711 712 self.session.recv(self._control_socket)
712 713 self._ignored_control_replies -= 1
713 714
714 715 def _flush_ignored_hub_replies(self):
715 716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 717 while msg is not None:
717 718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
718 719
719 720 def _flush_iopub(self, sock):
720 721 """Flush replies from the iopub channel waiting
721 722 in the ZMQ queue.
722 723 """
723 724 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
724 725 while msg is not None:
725 726 if self.debug:
726 727 pprint(msg)
727 728 parent = msg['parent_header']
728 729 msg_id = parent['msg_id']
729 730 content = msg['content']
730 731 header = msg['header']
731 732 msg_type = msg['header']['msg_type']
732 733
733 734 # init metadata:
734 735 md = self.metadata[msg_id]
735 736
736 737 if msg_type == 'stream':
737 738 name = content['name']
738 739 s = md[name] or ''
739 740 md[name] = s + content['data']
740 741 elif msg_type == 'pyerr':
741 742 md.update({'pyerr' : self._unwrap_exception(content)})
742 743 elif msg_type == 'pyin':
743 744 md.update({'pyin' : content['code']})
744 745 else:
745 746 md.update({msg_type : content.get('data', '')})
746 747
747 748 # redundant?
748 749 self.metadata[msg_id] = md
749 750
750 751 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
751 752
752 753 #--------------------------------------------------------------------------
753 754 # len, getitem
754 755 #--------------------------------------------------------------------------
755 756
756 757 def __len__(self):
757 758 """len(client) returns # of engines."""
758 759 return len(self.ids)
759 760
760 761 def __getitem__(self, key):
761 762 """index access returns DirectView multiplexer objects
762 763
763 764 Must be int, slice, or list/tuple/xrange of ints"""
764 765 if not isinstance(key, (int, slice, tuple, list, xrange)):
765 766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
766 767 else:
767 768 return self.direct_view(key)
768 769
769 770 #--------------------------------------------------------------------------
770 771 # Begin public methods
771 772 #--------------------------------------------------------------------------
772 773
773 774 @property
774 775 def ids(self):
775 776 """Always up-to-date ids property."""
776 777 self._flush_notifications()
777 778 # always copy:
778 779 return list(self._ids)
779 780
780 781 def close(self):
781 782 if self._closed:
782 783 return
783 784 snames = filter(lambda n: n.endswith('socket'), dir(self))
784 785 for socket in map(lambda name: getattr(self, name), snames):
785 786 if isinstance(socket, zmq.Socket) and not socket.closed:
786 787 socket.close()
787 788 self._closed = True
788 789
789 790 def spin(self):
790 791 """Flush any registration notifications and execution results
791 792 waiting in the ZMQ queue.
792 793 """
793 794 if self._notification_socket:
794 795 self._flush_notifications()
795 796 if self._mux_socket:
796 797 self._flush_results(self._mux_socket)
797 798 if self._task_socket:
798 799 self._flush_results(self._task_socket)
799 800 if self._control_socket:
800 801 self._flush_control(self._control_socket)
801 802 if self._iopub_socket:
802 803 self._flush_iopub(self._iopub_socket)
803 804 if self._query_socket:
804 805 self._flush_ignored_hub_replies()
805 806
806 807 def wait(self, jobs=None, timeout=-1):
807 808 """waits on one or more `jobs`, for up to `timeout` seconds.
808 809
809 810 Parameters
810 811 ----------
811 812
812 813 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
813 814 ints are indices to self.history
814 815 strs are msg_ids
815 816 default: wait on all outstanding messages
816 817 timeout : float
817 818 a time in seconds, after which to give up.
818 819 default is -1, which means no timeout
819 820
820 821 Returns
821 822 -------
822 823
823 824 True : when all msg_ids are done
824 825 False : timeout reached, some msg_ids still outstanding
825 826 """
826 827 tic = time.time()
827 828 if jobs is None:
828 829 theids = self.outstanding
829 830 else:
830 831 if isinstance(jobs, (int, basestring, AsyncResult)):
831 832 jobs = [jobs]
832 833 theids = set()
833 834 for job in jobs:
834 835 if isinstance(job, int):
835 836 # index access
836 837 job = self.history[job]
837 838 elif isinstance(job, AsyncResult):
838 839 map(theids.add, job.msg_ids)
839 840 continue
840 841 theids.add(job)
841 842 if not theids.intersection(self.outstanding):
842 843 return True
843 844 self.spin()
844 845 while theids.intersection(self.outstanding):
845 846 if timeout >= 0 and ( time.time()-tic ) > timeout:
846 847 break
847 848 time.sleep(1e-3)
848 849 self.spin()
849 850 return len(theids.intersection(self.outstanding)) == 0
850 851
851 852 #--------------------------------------------------------------------------
852 853 # Control methods
853 854 #--------------------------------------------------------------------------
854 855
855 856 @spin_first
856 857 def clear(self, targets=None, block=None):
857 858 """Clear the namespace in target(s)."""
858 859 block = self.block if block is None else block
859 860 targets = self._build_targets(targets)[0]
860 861 for t in targets:
861 862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
862 863 error = False
863 864 if block:
864 865 self._flush_ignored_control()
865 866 for i in range(len(targets)):
866 867 idents,msg = self.session.recv(self._control_socket,0)
867 868 if self.debug:
868 869 pprint(msg)
869 870 if msg['content']['status'] != 'ok':
870 871 error = self._unwrap_exception(msg['content'])
871 872 else:
872 873 self._ignored_control_replies += len(targets)
873 874 if error:
874 875 raise error
875 876
876 877
877 878 @spin_first
878 879 def abort(self, jobs=None, targets=None, block=None):
879 880 """Abort specific jobs from the execution queues of target(s).
880 881
881 882 This is a mechanism to prevent jobs that have already been submitted
882 883 from executing.
883 884
884 885 Parameters
885 886 ----------
886 887
887 888 jobs : msg_id, list of msg_ids, or AsyncResult
888 889 The jobs to be aborted
889 890
890 891 If unspecified/None: abort all outstanding jobs.
891 892
892 893 """
893 894 block = self.block if block is None else block
894 895 jobs = jobs if jobs is not None else list(self.outstanding)
895 896 targets = self._build_targets(targets)[0]
896 897
897 898 msg_ids = []
898 899 if isinstance(jobs, (basestring,AsyncResult)):
899 900 jobs = [jobs]
900 901 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
901 902 if bad_ids:
902 903 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
903 904 for j in jobs:
904 905 if isinstance(j, AsyncResult):
905 906 msg_ids.extend(j.msg_ids)
906 907 else:
907 908 msg_ids.append(j)
908 909 content = dict(msg_ids=msg_ids)
909 910 for t in targets:
910 911 self.session.send(self._control_socket, 'abort_request',
911 912 content=content, ident=t)
912 913 error = False
913 914 if block:
914 915 self._flush_ignored_control()
915 916 for i in range(len(targets)):
916 917 idents,msg = self.session.recv(self._control_socket,0)
917 918 if self.debug:
918 919 pprint(msg)
919 920 if msg['content']['status'] != 'ok':
920 921 error = self._unwrap_exception(msg['content'])
921 922 else:
922 923 self._ignored_control_replies += len(targets)
923 924 if error:
924 925 raise error
925 926
926 927 @spin_first
927 928 def shutdown(self, targets=None, restart=False, hub=False, block=None):
928 929 """Terminates one or more engine processes, optionally including the hub."""
929 930 block = self.block if block is None else block
930 931 if hub:
931 932 targets = 'all'
932 933 targets = self._build_targets(targets)[0]
933 934 for t in targets:
934 935 self.session.send(self._control_socket, 'shutdown_request',
935 936 content={'restart':restart},ident=t)
936 937 error = False
937 938 if block or hub:
938 939 self._flush_ignored_control()
939 940 for i in range(len(targets)):
940 941 idents,msg = self.session.recv(self._control_socket, 0)
941 942 if self.debug:
942 943 pprint(msg)
943 944 if msg['content']['status'] != 'ok':
944 945 error = self._unwrap_exception(msg['content'])
945 946 else:
946 947 self._ignored_control_replies += len(targets)
947 948
948 949 if hub:
949 950 time.sleep(0.25)
950 951 self.session.send(self._query_socket, 'shutdown_request')
951 952 idents,msg = self.session.recv(self._query_socket, 0)
952 953 if self.debug:
953 954 pprint(msg)
954 955 if msg['content']['status'] != 'ok':
955 956 error = self._unwrap_exception(msg['content'])
956 957
957 958 if error:
958 959 raise error
959 960
960 961 #--------------------------------------------------------------------------
961 962 # Execution related methods
962 963 #--------------------------------------------------------------------------
963 964
964 965 def _maybe_raise(self, result):
965 966 """wrapper for maybe raising an exception if apply failed."""
966 967 if isinstance(result, error.RemoteError):
967 968 raise result
968 969
969 970 return result
970 971
971 972 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
972 973 ident=None):
973 974 """construct and send an apply message via a socket.
974 975
975 976 This is the principal method with which all engine execution is performed by views.
976 977 """
977 978
978 979 assert not self._closed, "cannot use me anymore, I'm closed!"
979 980 # defaults:
980 981 args = args if args is not None else []
981 982 kwargs = kwargs if kwargs is not None else {}
982 983 subheader = subheader if subheader is not None else {}
983 984
984 985 # validate arguments
985 if not callable(f):
986 if not callable(f) and not isinstance(f, Reference):
986 987 raise TypeError("f must be callable, not %s"%type(f))
987 988 if not isinstance(args, (tuple, list)):
988 989 raise TypeError("args must be tuple or list, not %s"%type(args))
989 990 if not isinstance(kwargs, dict):
990 991 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
991 992 if not isinstance(subheader, dict):
992 993 raise TypeError("subheader must be dict, not %s"%type(subheader))
993 994
994 995 bufs = util.pack_apply_message(f,args,kwargs)
995 996
996 997 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
997 998 subheader=subheader, track=track)
998 999
999 1000 msg_id = msg['header']['msg_id']
1000 1001 self.outstanding.add(msg_id)
1001 1002 if ident:
1002 1003 # possibly routed to a specific engine
1003 1004 if isinstance(ident, list):
1004 1005 ident = ident[-1]
1005 1006 if ident in self._engines.values():
1006 1007 # save for later, in case of engine death
1007 1008 self._outstanding_dict[ident].add(msg_id)
1008 1009 self.history.append(msg_id)
1009 1010 self.metadata[msg_id]['submitted'] = datetime.now()
1010 1011
1011 1012 return msg
1012 1013
1013 1014 #--------------------------------------------------------------------------
1014 1015 # construct a View object
1015 1016 #--------------------------------------------------------------------------
1016 1017
1017 1018 def load_balanced_view(self, targets=None):
1018 1019 """construct a DirectView object.
1019 1020
1020 1021 If no arguments are specified, create a LoadBalancedView
1021 1022 using all engines.
1022 1023
1023 1024 Parameters
1024 1025 ----------
1025 1026
1026 1027 targets: list,slice,int,etc. [default: use all engines]
1027 1028 The subset of engines across which to load-balance
1028 1029 """
1029 1030 if targets == 'all':
1030 1031 targets = None
1031 1032 if targets is not None:
1032 1033 targets = self._build_targets(targets)[1]
1033 1034 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1034 1035
1035 1036 def direct_view(self, targets='all'):
1036 1037 """construct a DirectView object.
1037 1038
1038 1039 If no targets are specified, create a DirectView using all engines.
1039 1040
1040 1041 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1041 1042 evaluate the target engines at each execution, whereas rc[:] will connect to
1042 1043 all *current* engines, and that list will not change.
1043 1044
1044 1045 That is, 'all' will always use all engines, whereas rc[:] will not use
1045 1046 engines added after the DirectView is constructed.
1046 1047
1047 1048 Parameters
1048 1049 ----------
1049 1050
1050 1051 targets: list,slice,int,etc. [default: use all engines]
1051 1052 The engines to use for the View
1052 1053 """
1053 1054 single = isinstance(targets, int)
1054 1055 # allow 'all' to be lazily evaluated at each execution
1055 1056 if targets != 'all':
1056 1057 targets = self._build_targets(targets)[1]
1057 1058 if single:
1058 1059 targets = targets[0]
1059 1060 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1060 1061
1061 1062 #--------------------------------------------------------------------------
1062 1063 # Query methods
1063 1064 #--------------------------------------------------------------------------
1064 1065
1065 1066 @spin_first
1066 1067 def get_result(self, indices_or_msg_ids=None, block=None):
1067 1068 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1068 1069
1069 1070 If the client already has the results, no request to the Hub will be made.
1070 1071
1071 1072 This is a convenient way to construct AsyncResult objects, which are wrappers
1072 1073 that include metadata about execution, and allow for awaiting results that
1073 1074 were not submitted by this Client.
1074 1075
1075 1076 It can also be a convenient way to retrieve the metadata associated with
1076 1077 blocking execution, since it always retrieves the metadata along with the result.
1077 1078
1078 1079 Examples
1079 1080 --------
1080 1081 ::
1081 1082
1082 1083 In [10]: ar = client.get_result(msg_id)
1083 1084
1084 1085 Parameters
1085 1086 ----------
1086 1087
1087 1088 indices_or_msg_ids : integer history index, str msg_id, or list of either
1087 1088 The history indices or msg_ids of the results to be retrieved
1089 1090
1090 1091 block : bool
1091 1092 Whether to wait for the result to be done
1092 1093
1093 1094 Returns
1094 1095 -------
1095 1096
1096 1097 AsyncResult
1097 1098 A single AsyncResult object will always be returned.
1098 1099
1099 1100 AsyncHubResult
1100 1101 A subclass of AsyncResult that retrieves results from the Hub
1101 1102
1102 1103 """
1103 1104 block = self.block if block is None else block
1104 1105 if indices_or_msg_ids is None:
1105 1106 indices_or_msg_ids = -1
1106 1107
1107 1108 if not isinstance(indices_or_msg_ids, (list,tuple)):
1108 1109 indices_or_msg_ids = [indices_or_msg_ids]
1109 1110
1110 1111 theids = []
1111 1112 for id in indices_or_msg_ids:
1112 1113 if isinstance(id, int):
1113 1114 id = self.history[id]
1114 1115 if not isinstance(id, basestring):
1115 1116 raise TypeError("indices must be str or int, not %r"%id)
1116 1117 theids.append(id)
1117 1118
1118 1119 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1119 1120 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1120 1121
1121 1122 if remote_ids:
1122 1123 ar = AsyncHubResult(self, msg_ids=theids)
1123 1124 else:
1124 1125 ar = AsyncResult(self, msg_ids=theids)
1125 1126
1126 1127 if block:
1127 1128 ar.wait()
1128 1129
1129 1130 return ar
1130 1131
1131 1132 @spin_first
1132 1133 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1133 1134 """Resubmit one or more tasks.
1134 1135
1135 1136 In-flight tasks may not be resubmitted.
1136 1137
1137 1138 Parameters
1138 1139 ----------
1139 1140
1140 1141 indices_or_msg_ids : integer history index, str msg_id, or list of either
1140 1141 The history indices or msg_ids of the tasks to be resubmitted
1142 1143
1143 1144 block : bool
1144 1145 Whether to wait for the result to be done
1145 1146
1146 1147 Returns
1147 1148 -------
1148 1149
1149 1150 AsyncHubResult
1150 1151 A subclass of AsyncResult that retrieves results from the Hub
1151 1152
1152 1153 """
1153 1154 block = self.block if block is None else block
1154 1155 if indices_or_msg_ids is None:
1155 1156 indices_or_msg_ids = -1
1156 1157
1157 1158 if not isinstance(indices_or_msg_ids, (list,tuple)):
1158 1159 indices_or_msg_ids = [indices_or_msg_ids]
1159 1160
1160 1161 theids = []
1161 1162 for id in indices_or_msg_ids:
1162 1163 if isinstance(id, int):
1163 1164 id = self.history[id]
1164 1165 if not isinstance(id, basestring):
1165 1166 raise TypeError("indices must be str or int, not %r"%id)
1166 1167 theids.append(id)
1167 1168
1168 1169 for msg_id in theids:
1169 1170 self.outstanding.discard(msg_id)
1170 1171 if msg_id in self.history:
1171 1172 self.history.remove(msg_id)
1172 1173 self.results.pop(msg_id, None)
1173 1174 self.metadata.pop(msg_id, None)
1174 1175 content = dict(msg_ids = theids)
1175 1176
1176 1177 self.session.send(self._query_socket, 'resubmit_request', content)
1177 1178
1178 1179 zmq.select([self._query_socket], [], [])
1179 1180 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1180 1181 if self.debug:
1181 1182 pprint(msg)
1182 1183 content = msg['content']
1183 1184 if content['status'] != 'ok':
1184 1185 raise self._unwrap_exception(content)
1185 1186
1186 1187 ar = AsyncHubResult(self, msg_ids=theids)
1187 1188
1188 1189 if block:
1189 1190 ar.wait()
1190 1191
1191 1192 return ar
1192 1193
1193 1194 @spin_first
1194 1195 def result_status(self, msg_ids, status_only=True):
1195 1196 """Check on the status of the result(s) of the apply request with `msg_ids`.
1196 1197
1197 1198 If status_only is False, then the actual results will be retrieved, else
1198 1199 only the status of the results will be checked.
1199 1200
1200 1201 Parameters
1201 1202 ----------
1202 1203
1203 1204 msg_ids : list of msg_ids
1204 1205 if int:
1205 1206 Passed as index to self.history for convenience.
1206 1207 status_only : bool (default: True)
1207 1208 if False:
1208 1209 Retrieve the actual results of completed tasks.
1209 1210
1210 1211 Returns
1211 1212 -------
1212 1213
1213 1214 results : dict
1214 1215 There will always be the keys 'pending' and 'completed', which will
1215 1216 be lists of msg_ids that are incomplete or complete. If `status_only`
1216 1217 is False, then completed results will be keyed by their `msg_id`.
1217 1218 """
1218 1219 if not isinstance(msg_ids, (list,tuple)):
1219 1220 msg_ids = [msg_ids]
1220 1221
1221 1222 theids = []
1222 1223 for msg_id in msg_ids:
1223 1224 if isinstance(msg_id, int):
1224 1225 msg_id = self.history[msg_id]
1225 1226 if not isinstance(msg_id, basestring):
1226 1227 raise TypeError("msg_ids must be str, not %r"%msg_id)
1227 1228 theids.append(msg_id)
1228 1229
1229 1230 completed = []
1230 1231 local_results = {}
1231 1232
1232 1233 # comment this block out to temporarily disable local shortcut:
1233 1234 for msg_id in list(theids):
1234 1235 if msg_id in self.results:
1235 1236 completed.append(msg_id)
1236 1237 local_results[msg_id] = self.results[msg_id]
1237 1238 theids.remove(msg_id)
1238 1239
1239 1240 if theids: # some not locally cached
1240 1241 content = dict(msg_ids=theids, status_only=status_only)
1241 1242 msg = self.session.send(self._query_socket, "result_request", content=content)
1242 1243 zmq.select([self._query_socket], [], [])
1243 1244 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1244 1245 if self.debug:
1245 1246 pprint(msg)
1246 1247 content = msg['content']
1247 1248 if content['status'] != 'ok':
1248 1249 raise self._unwrap_exception(content)
1249 1250 buffers = msg['buffers']
1250 1251 else:
1251 1252 content = dict(completed=[],pending=[])
1252 1253
1253 1254 content['completed'].extend(completed)
1254 1255
1255 1256 if status_only:
1256 1257 return content
1257 1258
1258 1259 failures = []
1259 1260 # load cached results into result:
1260 1261 content.update(local_results)
1261 1262
1262 1263 # update cache with results:
1263 1264 for msg_id in sorted(theids):
1264 1265 if msg_id in content['completed']:
1265 1266 rec = content[msg_id]
1266 1267 parent = rec['header']
1267 1268 header = rec['result_header']
1268 1269 rcontent = rec['result_content']
1269 1270 iodict = rec['io']
1270 1271 if isinstance(rcontent, str):
1271 1272 rcontent = self.session.unpack(rcontent)
1272 1273
1273 1274 md = self.metadata[msg_id]
1274 1275 md.update(self._extract_metadata(header, parent, rcontent))
1275 1276 md.update(iodict)
1276 1277
1277 1278 if rcontent['status'] == 'ok':
1278 1279 res,buffers = util.unserialize_object(buffers)
1279 1280 else:
1280 1281 print rcontent
1281 1282 res = self._unwrap_exception(rcontent)
1282 1283 failures.append(res)
1283 1284
1284 1285 self.results[msg_id] = res
1285 1286 content[msg_id] = res
1286 1287
1287 1288 if len(theids) == 1 and failures:
1288 1289 raise failures[0]
1289 1290
1290 1291 error.collect_exceptions(failures, "result_status")
1291 1292 return content
1292 1293
1293 1294 @spin_first
1294 1295 def queue_status(self, targets='all', verbose=False):
1295 1296 """Fetch the status of engine queues.
1296 1297
1297 1298 Parameters
1298 1299 ----------
1299 1300
1300 1301 targets : int/str/list of ints/strs
1301 1302 the engines whose states are to be queried.
1302 1303 default : all
1303 1304 verbose : bool
1304 1305 Whether to return lengths only, or lists of ids for each element
1305 1306 """
1306 1307 engine_ids = self._build_targets(targets)[1]
1307 1308 content = dict(targets=engine_ids, verbose=verbose)
1308 1309 self.session.send(self._query_socket, "queue_request", content=content)
1309 1310 idents,msg = self.session.recv(self._query_socket, 0)
1310 1311 if self.debug:
1311 1312 pprint(msg)
1312 1313 content = msg['content']
1313 1314 status = content.pop('status')
1314 1315 if status != 'ok':
1315 1316 raise self._unwrap_exception(content)
1316 1317 content = rekey(content)
1317 1318 if isinstance(targets, int):
1318 1319 return content[targets]
1319 1320 else:
1320 1321 return content
1321 1322
1322 1323 @spin_first
1323 1324 def purge_results(self, jobs=[], targets=[]):
1324 1325 """Tell the Hub to forget results.
1325 1326
1326 1327 Individual results can be purged by msg_id, or the entire
1327 1328 history of specific targets can be purged.
1328 1329
1329 1330 Use `purge_results('all')` to scrub everything from the Hub's db.
1330 1331
1331 1332 Parameters
1332 1333 ----------
1333 1334
1334 1335 jobs : str or list of str or AsyncResult objects
1335 1336 the msg_ids whose results should be forgotten.
1336 1337 targets : int/str/list of ints/strs
1337 1338 The targets, by int_id, whose entire history is to be purged.
1338 1339
1339 1340 default : None
1340 1341 """
1341 1342 if not targets and not jobs:
1342 1343 raise ValueError("Must specify at least one of `targets` and `jobs`")
1343 1344 if targets:
1344 1345 targets = self._build_targets(targets)[1]
1345 1346
1346 1347 # construct msg_ids from jobs
1347 1348 if jobs == 'all':
1348 1349 msg_ids = jobs
1349 1350 else:
1350 1351 msg_ids = []
1351 1352 if isinstance(jobs, (basestring,AsyncResult)):
1352 1353 jobs = [jobs]
1353 1354 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1354 1355 if bad_ids:
1355 1356 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1356 1357 for j in jobs:
1357 1358 if isinstance(j, AsyncResult):
1358 1359 msg_ids.extend(j.msg_ids)
1359 1360 else:
1360 1361 msg_ids.append(j)
1361 1362
1362 1363 content = dict(engine_ids=targets, msg_ids=msg_ids)
1363 1364 self.session.send(self._query_socket, "purge_request", content=content)
1364 1365 idents, msg = self.session.recv(self._query_socket, 0)
1365 1366 if self.debug:
1366 1367 pprint(msg)
1367 1368 content = msg['content']
1368 1369 if content['status'] != 'ok':
1369 1370 raise self._unwrap_exception(content)
1370 1371
1371 1372 @spin_first
1372 1373 def hub_history(self):
1373 1374 """Get the Hub's history
1374 1375
1375 1376 Just like the Client, the Hub has a history, which is a list of msg_ids.
1376 1377 This will contain the history of all clients, and, depending on configuration,
1377 1378 may contain history across multiple cluster sessions.
1378 1379
1379 1380 Any msg_id returned here is a valid argument to `get_result`.
1380 1381
1381 1382 Returns
1382 1383 -------
1383 1384
1384 1385 msg_ids : list of strs
1385 1386 list of all msg_ids, ordered by task submission time.
1386 1387 """
1387 1388
1388 1389 self.session.send(self._query_socket, "history_request", content={})
1389 1390 idents, msg = self.session.recv(self._query_socket, 0)
1390 1391
1391 1392 if self.debug:
1392 1393 pprint(msg)
1393 1394 content = msg['content']
1394 1395 if content['status'] != 'ok':
1395 1396 raise self._unwrap_exception(content)
1396 1397 else:
1397 1398 return content['history']
1398 1399
1399 1400 @spin_first
1400 1401 def db_query(self, query, keys=None):
1401 1402 """Query the Hub's TaskRecord database
1402 1403
1403 1404 This will return a list of task record dicts that match `query`
1404 1405
1405 1406 Parameters
1406 1407 ----------
1407 1408
1408 1409 query : mongodb query dict
1409 1410 The search dict. See mongodb query docs for details.
1410 1411 keys : list of strs [optional]
1411 1412 The subset of keys to be returned. The default is to fetch everything but buffers.
1412 1413 'msg_id' will *always* be included.
1413 1414 """
1414 1415 if isinstance(keys, basestring):
1415 1416 keys = [keys]
1416 1417 content = dict(query=query, keys=keys)
1417 1418 self.session.send(self._query_socket, "db_request", content=content)
1418 1419 idents, msg = self.session.recv(self._query_socket, 0)
1419 1420 if self.debug:
1420 1421 pprint(msg)
1421 1422 content = msg['content']
1422 1423 if content['status'] != 'ok':
1423 1424 raise self._unwrap_exception(content)
1424 1425
1425 1426 records = content['records']
1426 1427
1427 1428 buffer_lens = content['buffer_lens']
1428 1429 result_buffer_lens = content['result_buffer_lens']
1429 1430 buffers = msg['buffers']
1430 1431 has_bufs = buffer_lens is not None
1431 1432 has_rbufs = result_buffer_lens is not None
1432 1433 for i,rec in enumerate(records):
1433 1434 # relink buffers
1434 1435 if has_bufs:
1435 1436 blen = buffer_lens[i]
1436 1437 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1437 1438 if has_rbufs:
1438 1439 blen = result_buffer_lens[i]
1439 1440 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1440 1441
1441 1442 return records
1442 1443
1443 1444 __all__ = [ 'Client' ]
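Between the two files, this is the flow the diff touches: views funnel all execution through `Client.send_apply_message` (where the callable check was relaxed above), and replies come back through `spin`/`wait`. A minimal end-to-end sketch, assuming a cluster started with `ipcluster start`:

    from IPython.parallel import Client

    rc = Client()                        # reads ipcontroller-client.json for the active profile
    dv = rc[:]                           # DirectView over all currently registered engines
    ar = dv.apply(lambda x: x ** 2, 7)   # routed through send_apply_message on the mux socket
    rc.wait([ar], timeout=10)            # spin until the msg_ids resolve, or time out
    print(ar.get())                      # one result per engine, e.g. [49, 49]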
@@ -1,222 +1,241 @@
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.testing.skipdoctest import skip_doctest
25 25
26 26 from . import map as Map
27 27 from .asyncresult import AsyncMapResult
28 28
29 29 #-----------------------------------------------------------------------------
30 # Decorators
30 # Functions and Decorators
31 31 #-----------------------------------------------------------------------------
32 32
33 33 @skip_doctest
34 34 def remote(view, block=None, **flags):
35 35 """Turn a function into a remote function.
36 36
37 37 This function can be used as a decorator:
38 38
39 39 In [1]: @remote(view,block=True)
40 40 ...: def func(a):
41 41 ...: pass
42 42 """
43 43
44 44 def remote_function(f):
45 45 return RemoteFunction(view, f, block=block, **flags)
46 46 return remote_function
47 47
48 48 @skip_doctest
49 49 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 50 """Turn a function into a parallel remote function.
51 51
52 52 This method can be used for map:
53 53
54 54 In [1]: @parallel(view, block=True)
55 55 ...: def func(a):
56 56 ...: pass
57 57 """
58 58
59 59 def parallel_function(f):
60 60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 61 return parallel_function
62 62
63 def getname(f):
64 """Get the name of an object.
65
66 For use in case of callables that are not functions, and
67 thus may not have __name__ defined.
68
69 Order: f.__name__ > f.name > str(f)
70 """
71 try:
72 return f.__name__
73 except:
74 pass
75 try:
76 return f.name
77 except:
78 pass
79
80 return str(f)
81
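The fallback order documented in `getname` above is what gives a `Reference` a usable label: it carries a `name` attribute but has no `__name__`. A quick sketch of the three branches:

    getname(len)             # 'len'   (via __name__)
    getname(Reference('f'))  # u'f'    (via the .name attribute)
    getname(3.5)             # '3.5'   (final fallback: str(f))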
63 82 #--------------------------------------------------------------------------
64 83 # Classes
65 84 #--------------------------------------------------------------------------
66 85
67 86 class RemoteFunction(object):
68 87 """Turn an existing function into a remote function.
69 88
70 89 Parameters
71 90 ----------
72 91
73 92 view : View instance
74 93 The view to be used for execution
75 94 f : callable
76 95 The function to be wrapped into a remote function
77 96 block : bool [default: None]
78 97 Whether to wait for results or not. The default behavior is
79 98 to use the current `block` attribute of `view`
80 99
81 100 **flags : remaining kwargs are passed to View.temp_flags
82 101 """
83 102
84 103 view = None # the remote connection
85 104 func = None # the wrapped function
86 105 block = None # whether to block
87 106 flags = None # dict of extra kwargs for temp_flags
88 107
89 108 def __init__(self, view, f, block=None, **flags):
90 109 self.view = view
91 110 self.func = f
92 111 self.block=block
93 112 self.flags=flags
94 113
95 114 def __call__(self, *args, **kwargs):
96 115 block = self.view.block if self.block is None else self.block
97 116 with self.view.temp_flags(block=block, **self.flags):
98 117 return self.view.apply(self.func, *args, **kwargs)
99 118
100 119
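Equivalently to the `@remote` decorator, a `RemoteFunction` can be built by hand; a sketch, reusing the illustrative `view` from above:

    def double(x):
        return 2 * x

    rf = RemoteFunction(view, double, block=True)
    print rf(21)    # runs on every engine in the view -> [42, 42, ...]
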
101 120 class ParallelFunction(RemoteFunction):
102 121 """Class for mapping a function to sequences.
103 122
104 123     This will distribute the sequences according to a mapper, and call
105 124     the function on each sub-sequence.  If called via map, then the function
106 125     will be called once on each element, rather than on each sub-sequence.
107 126
108 127 Parameters
109 128 ----------
110 129
111 130 view : View instance
112 131 The view to be used for execution
113 132 f : callable
114 133 The function to be wrapped into a remote function
115 134 dist : str [default: 'b']
116 135 The key for which mapObject to use to distribute sequences
117 136 options are:
118 137 * 'b' : use contiguous chunks in order
119 138 * 'r' : use round-robin striping
120 139 block : bool [default: None]
121 140 Whether to wait for results or not. The default behavior is
122 141 to use the current `block` attribute of `view`
123 142 chunksize : int or None
124 143 The size of chunk to use when breaking up sequences in a load-balanced manner
125 144 ordered : bool [default: True]
126 145         Whether the results should be gathered as they arrive, or enforce the order of submission.
127 146 **flags : remaining kwargs are passed to View.temp_flags
128 147 """
129 148
130 149 chunksize=None
131 150 ordered=None
132 151 mapObject=None
133 152
134 153 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
135 154 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
136 155 self.chunksize = chunksize
137 156 self.ordered = ordered
138 157
139 158 mapClass = Map.dists[dist]
140 159 self.mapObject = mapClass()
141 160
142 161 def __call__(self, *sequences):
143 162 client = self.view.client
144 163
145 164 # check that the length of sequences match
146 165 len_0 = len(sequences[0])
147 166 for s in sequences:
148 167 if len(s)!=len_0:
149 168 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
150 169 raise ValueError(msg)
151 170 balanced = 'Balanced' in self.view.__class__.__name__
152 171 if balanced:
153 172 if self.chunksize:
154 173 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
155 174 else:
156 175 nparts = len_0
157 176 targets = [None]*nparts
158 177 else:
159 178 if self.chunksize:
160 179 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
161 180 # multiplexed:
162 181 targets = self.view.targets
163 182 # 'all' is lazily evaluated at execution time, which is now:
164 183 if targets == 'all':
165 184 targets = client._build_targets(targets)[1]
166 185 nparts = len(targets)
167 186
168 187 msg_ids = []
169 188 for index, t in enumerate(targets):
170 189 args = []
171 190 for seq in sequences:
172 191 part = self.mapObject.getPartition(seq, index, nparts)
173 192 if len(part) == 0:
174 193 continue
175 194 else:
176 195 args.append(part)
177 196 if not args:
178 197 continue
179 198
180 199 # print (args)
181 200 if hasattr(self, '_map'):
182 201 if sys.version_info[0] >= 3:
183 202 f = lambda f, *sequences: list(map(f, *sequences))
184 203 else:
185 204 f = map
186 205 args = [self.func]+args
187 206 else:
188 207 f=self.func
189 208
190 209 view = self.view if balanced else client[t]
191 210 with view.temp_flags(block=False, **self.flags):
192 211 ar = view.apply(f, *args)
193 212
194 213 msg_ids.append(ar.msg_ids[0])
195 214
196 215 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
197 fname=self.func.__name__,
216 fname=getname(self.func),
198 217 ordered=self.ordered
199 218 )
200 219
201 220 if self.block:
202 221 try:
203 222 return r.get()
204 223 except KeyboardInterrupt:
205 224 return r
206 225 else:
207 226 return r
208 227
209 228 def map(self, *sequences):
210 229 """call a function on each element of a sequence remotely.
211 230 This should behave very much like the builtin map, but return an AsyncMapResult
212 231 if self.block is False.
213 232 """
214 233 # set _map as a flag for use inside self.__call__
215 234 self._map = True
216 235 try:
217 236 ret = self.__call__(*sequences)
218 237 finally:
219 238 del self._map
220 239 return ret
221 240
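The effect of the `dist` key is easiest to see on the mappers directly; a sketch (expected partitions shown in comments; exact chunk boundaries are up to the Map classes in `IPython.parallel.client.map`):

    bmap = Map.dists['b']()    # contiguous chunks
    rmap = Map.dists['r']()    # round-robin striping

    print [bmap.getPartition(range(6), i, 3) for i in range(3)]
    # expected: [[0, 1], [2, 3], [4, 5]]
    print [rmap.getPartition(range(6), i, 3) for i in range(3)]
    # expected: [[0, 3], [1, 4], [2, 5]]
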
222 241 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1065 +1,1065 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89         control methods also spin, and requesting `ids` ensures the state is up to date
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130 130
131 131
132 132 def __repr__(self):
133 133 strtargets = str(self.targets)
134 134 if len(strtargets) > 16:
135 135 strtargets = strtargets[:12]+'...]'
136 136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137 137
138 138 def set_flags(self, **kwargs):
139 139 """set my attribute flags by keyword.
140 140
141 141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 142 These attributes can be set all at once by name with this method.
143 143
144 144 Parameters
145 145 ----------
146 146
147 147 block : bool
148 148 whether to wait for results
149 149 track : bool
150 150 whether to create a MessageTracker to allow the user to
151 151             safely edit arrays and buffers after non-copying
152 152 sends.
153 153 """
154 154 for name, value in kwargs.iteritems():
155 155 if name not in self._flag_names:
156 156 raise KeyError("Invalid name: %r"%name)
157 157 else:
158 158 setattr(self, name, value)
159 159
160 160 @contextmanager
161 161 def temp_flags(self, **kwargs):
162 162 """temporarily set flags, for use in `with` statements.
163 163
164 164 See set_flags for permanent setting of flags
165 165
166 166 Examples
167 167 --------
168 168
169 169 >>> view.track=False
170 170 ...
171 171 >>> with view.temp_flags(track=True):
172 172 ... ar = view.apply(dostuff, my_big_array)
173 173 ... ar.tracker.wait() # wait for send to finish
174 174 >>> view.track
175 175 False
176 176
177 177 """
178 178 # preflight: save flags, and set temporaries
179 179 saved_flags = {}
180 180 for f in self._flag_names:
181 181 saved_flags[f] = getattr(self, f)
182 182 self.set_flags(**kwargs)
183 183 # yield to the with-statement block
184 184 try:
185 185 yield
186 186 finally:
187 187 # postflight: restore saved flags
188 188 self.set_flags(**saved_flags)
189 189
190 190
191 191 #----------------------------------------------------------------
192 192 # apply
193 193 #----------------------------------------------------------------
194 194
195 195 @sync_results
196 196 @save_ids
197 197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 198 """wrapper for client.send_apply_message"""
199 199 raise NotImplementedError("Implement in subclasses")
200 200
201 201 def apply(self, f, *args, **kwargs):
202 202 """calls f(*args, **kwargs) on remote engines, returning the result.
203 203
204 204 This method sets all apply flags via this View's attributes.
205 205
206 206 if self.block is False:
207 207 returns AsyncResult
208 208 else:
209 209 returns actual result of f(*args, **kwargs)
210 210 """
211 211 return self._really_apply(f, args, kwargs)
212 212
213 213 def apply_async(self, f, *args, **kwargs):
214 214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215 215
216 216 returns AsyncResult
217 217 """
218 218 return self._really_apply(f, args, kwargs, block=False)
219 219
220 220 @spin_after
221 221 def apply_sync(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 223 returning the result.
224 224
225 225 returns: actual result of f(*args, **kwargs)
226 226 """
227 227 return self._really_apply(f, args, kwargs, block=True)
228 228
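The three flavors differ only in how `block` is resolved; a sketch, assuming a DirectView `dv = client[:]`:

    ar = dv.apply_async(lambda x: x**2, 4)   # returns an AsyncResult immediately
    print ar.get()                           # one result per engine: [16, 16, ...]
    print dv.apply_sync(lambda x: x**2, 4)   # blocks until the results arrive
    print dv.apply(lambda x: x**2, 4)        # honors dv.block
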
229 229 #----------------------------------------------------------------
230 230 # wrappers for client and control methods
231 231 #----------------------------------------------------------------
232 232 @sync_results
233 233 def spin(self):
234 234 """spin the client, and sync"""
235 235 self.client.spin()
236 236
237 237 @sync_results
238 238 def wait(self, jobs=None, timeout=-1):
239 239 """waits on one or more `jobs`, for up to `timeout` seconds.
240 240
241 241 Parameters
242 242 ----------
243 243
244 244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 245 ints are indices to self.history
246 246 strs are msg_ids
247 247 default: wait on all outstanding messages
248 248 timeout : float
249 249 a time in seconds, after which to give up.
250 250 default is -1, which means no timeout
251 251
252 252 Returns
253 253 -------
254 254
255 255 True : when all msg_ids are done
256 256 False : timeout reached, some msg_ids still outstanding
257 257 """
258 258 if jobs is None:
259 259 jobs = self.history
260 260 return self.client.wait(jobs, timeout)
261 261
262 262 def abort(self, jobs=None, targets=None, block=None):
263 263 """Abort jobs on my engines.
264 264
265 265 Parameters
266 266 ----------
267 267
268 268 jobs : None, str, list of strs, optional
269 269 if None: abort all jobs.
270 270 else: abort specific msg_id(s).
271 271 """
272 272 block = block if block is not None else self.block
273 273 targets = targets if targets is not None else self.targets
274 274 jobs = jobs if jobs is not None else list(self.outstanding)
275 275
276 276 return self.client.abort(jobs=jobs, targets=targets, block=block)
277 277
278 278 def queue_status(self, targets=None, verbose=False):
279 279 """Fetch the Queue status of my engines"""
280 280 targets = targets if targets is not None else self.targets
281 281 return self.client.queue_status(targets=targets, verbose=verbose)
282 282
283 283 def purge_results(self, jobs=[], targets=[]):
284 284 """Instruct the controller to forget specific results."""
285 285 if targets is None or targets == 'all':
286 286 targets = self.targets
287 287 return self.client.purge_results(jobs=jobs, targets=targets)
288 288
289 289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 290 """Terminates one or more engine processes, optionally including the hub.
291 291 """
292 292 block = self.block if block is None else block
293 293 if targets is None or targets == 'all':
294 294 targets = self.targets
295 295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296 296
297 297 @spin_after
298 298 def get_result(self, indices_or_msg_ids=None):
299 299 """return one or more results, specified by history index or msg_id.
300 300
301 301 See client.get_result for details.
302 302
303 303 """
304 304
305 305 if indices_or_msg_ids is None:
306 306 indices_or_msg_ids = -1
307 307 if isinstance(indices_or_msg_ids, int):
308 308 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 310 indices_or_msg_ids = list(indices_or_msg_ids)
311 311 for i,index in enumerate(indices_or_msg_ids):
312 312 if isinstance(index, int):
313 313 indices_or_msg_ids[i] = self.history[index]
314 314 return self.client.get_result(indices_or_msg_ids)
315 315
316 316 #-------------------------------------------------------------------
317 317 # Map
318 318 #-------------------------------------------------------------------
319 319
320 320 def map(self, f, *sequences, **kwargs):
321 321 """override in subclasses"""
322 322 raise NotImplementedError
323 323
324 324 def map_async(self, f, *sequences, **kwargs):
325 325 """Parallel version of builtin `map`, using this view's engines.
326 326
327 327 This is equivalent to map(...block=False)
328 328
329 329 See `self.map` for details.
330 330 """
331 331 if 'block' in kwargs:
332 332 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 333 kwargs['block'] = False
334 334 return self.map(f,*sequences,**kwargs)
335 335
336 336 def map_sync(self, f, *sequences, **kwargs):
337 337 """Parallel version of builtin `map`, using this view's engines.
338 338
339 339 This is equivalent to map(...block=True)
340 340
341 341 See `self.map` for details.
342 342 """
343 343 if 'block' in kwargs:
344 344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 345 kwargs['block'] = True
346 346 return self.map(f,*sequences,**kwargs)
347 347
348 348 def imap(self, f, *sequences, **kwargs):
349 349 """Parallel version of `itertools.imap`.
350 350
351 351 See `self.map` for details.
352 352
353 353 """
354 354
355 355 return iter(self.map_async(f,*sequences, **kwargs))
356 356
357 357 #-------------------------------------------------------------------
358 358 # Decorators
359 359 #-------------------------------------------------------------------
360 360
361 361 def remote(self, block=True, **flags):
362 362 """Decorator for making a RemoteFunction"""
363 363 block = self.block if block is None else block
364 364 return remote(self, block=block, **flags)
365 365
366 366 def parallel(self, dist='b', block=None, **flags):
367 367 """Decorator for making a ParallelFunction"""
368 368 block = self.block if block is None else block
369 369 return parallel(self, dist=dist, block=block, **flags)
370 370
371 371 @skip_doctest
372 372 class DirectView(View):
373 373 """Direct Multiplexer View of one or more engines.
374 374
375 375 These are created via indexed access to a client:
376 376
377 377 >>> dv_1 = client[1]
378 378 >>> dv_all = client[:]
379 379 >>> dv_even = client[::2]
380 380 >>> dv_some = client[1:3]
381 381
382 382 This object provides dictionary access to engine namespaces:
383 383
384 384 # push a=5:
385 385 >>> dv['a'] = 5
386 386 # pull 'foo':
387 387     >>> dv['foo']
388 388
389 389 """
390 390
391 391 def __init__(self, client=None, socket=None, targets=None):
392 392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393 393
394 394 @property
395 395 def importer(self):
396 396 """sync_imports(local=True) as a property.
397 397
398 398 See sync_imports for details.
399 399
400 400 """
401 401 return self.sync_imports(True)
402 402
403 403 @contextmanager
404 404 def sync_imports(self, local=True):
405 405 """Context Manager for performing simultaneous local and remote imports.
406 406
407 407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408 408
409 409 If `local=True`, then the package will also be imported locally.
410 410
411 411 Note that remote-only (`local=False`) imports have not been implemented.
412 412
413 413 >>> with view.sync_imports():
414 414 ... from numpy import recarray
415 415 importing recarray from numpy on engine(s)
416 416
417 417 """
418 418 import __builtin__
419 419 local_import = __builtin__.__import__
420 420 modules = set()
421 421 results = []
422 422 @util.interactive
423 423 def remote_import(name, fromlist, level):
424 424 """the function to be passed to apply, that actually performs the import
425 425 on the engine, and loads up the user namespace.
426 426 """
427 427 import sys
428 428 user_ns = globals()
429 429 mod = __import__(name, fromlist=fromlist, level=level)
430 430 if fromlist:
431 431 for key in fromlist:
432 432 user_ns[key] = getattr(mod, key)
433 433 else:
434 434 user_ns[name] = sys.modules[name]
435 435
436 436 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
437 437 """the drop-in replacement for __import__, that optionally imports
438 438 locally as well.
439 439 """
440 440 # don't override nested imports
441 441 save_import = __builtin__.__import__
442 442 __builtin__.__import__ = local_import
443 443
444 444 if imp.lock_held():
445 445 # this is a side-effect import, don't do it remotely, or even
446 446 # ignore the local effects
447 447 return local_import(name, globals, locals, fromlist, level)
448 448
449 449 imp.acquire_lock()
450 450 if local:
451 451 mod = local_import(name, globals, locals, fromlist, level)
452 452 else:
453 453 raise NotImplementedError("remote-only imports not yet implemented")
454 454 imp.release_lock()
455 455
456 456 key = name+':'+','.join(fromlist or [])
457 457 if level == -1 and key not in modules:
458 458 modules.add(key)
459 459 if fromlist:
460 460 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
461 461 else:
462 462 print "importing %s on engine(s)"%name
463 463 results.append(self.apply_async(remote_import, name, fromlist, level))
464 464 # restore override
465 465 __builtin__.__import__ = save_import
466 466
467 467 return mod
468 468
469 469 # override __import__
470 470 __builtin__.__import__ = view_import
471 471 try:
472 472 # enter the block
473 473 yield
474 474 except ImportError:
475 475 if local:
476 476 raise
477 477 else:
478 478 # ignore import errors if not doing local imports
479 479 pass
480 480 finally:
481 481 # always restore __import__
482 482 __builtin__.__import__ = local_import
483 483
484 484 for r in results:
485 485 # raise possible remote ImportErrors here
486 486 r.get()
487 487
488 488
489 489 @sync_results
490 490 @save_ids
491 491 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
492 492 """calls f(*args, **kwargs) on remote engines, returning the result.
493 493
494 494 This method sets all of `apply`'s flags via this View's attributes.
495 495
496 496 Parameters
497 497 ----------
498 498
499 499 f : callable
500 500
501 501 args : list [default: empty]
502 502
503 503 kwargs : dict [default: empty]
504 504
505 505 targets : target list [default: self.targets]
506 506 where to run
507 507 block : bool [default: self.block]
508 508 whether to block
509 509 track : bool [default: self.track]
510 510 whether to ask zmq to track the message, for safe non-copying sends
511 511
512 512 Returns
513 513 -------
514 514
515 515 if self.block is False:
516 516 returns AsyncResult
517 517 else:
518 518 returns actual result of f(*args, **kwargs) on the engine(s)
519 519             This will be a list if self.targets is also a list (even length 1), or
520 520 the single result if self.targets is an integer engine id
521 521 """
522 522 args = [] if args is None else args
523 523 kwargs = {} if kwargs is None else kwargs
524 524 block = self.block if block is None else block
525 525 track = self.track if track is None else track
526 526 targets = self.targets if targets is None else targets
527 527
528 528 _idents = self.client._build_targets(targets)[0]
529 529 msg_ids = []
530 530 trackers = []
531 531 for ident in _idents:
532 532 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
533 533 ident=ident)
534 534 if track:
535 535 trackers.append(msg['tracker'])
536 536 msg_ids.append(msg['header']['msg_id'])
537 537 tracker = None if track is False else zmq.MessageTracker(*trackers)
538 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
538 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
539 539 if block:
540 540 try:
541 541 return ar.get()
542 542 except KeyboardInterrupt:
543 543 pass
544 544 return ar
545 545
546 546 @spin_after
547 547 def map(self, f, *sequences, **kwargs):
548 548 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
549 549
550 550 Parallel version of builtin `map`, using this View's `targets`.
551 551
552 552 There will be one task per target, so work will be chunked
553 553         if the sequences are longer than the number of targets.
554 554
555 555 Results can be iterated as they are ready, but will become available in chunks.
556 556
557 557 Parameters
558 558 ----------
559 559
560 560 f : callable
561 561 function to be mapped
562 562 *sequences: one or more sequences of matching length
563 563 the sequences to be distributed and passed to `f`
564 564 block : bool
565 565 whether to wait for the result or not [default self.block]
566 566
567 567 Returns
568 568 -------
569 569
570 570 if block=False:
571 571 AsyncMapResult
572 572 An object like AsyncResult, but which reassembles the sequence of results
573 573 into a single list. AsyncMapResults can be iterated through before all
574 574 results are complete.
575 575 else:
576 576 list
577 577 the result of map(f,*sequences)
578 578 """
579 579
580 580 block = kwargs.pop('block', self.block)
581 581 for k in kwargs.keys():
582 582 if k not in ['block', 'track']:
583 583 raise TypeError("invalid keyword arg, %r"%k)
584 584
585 585 assert len(sequences) > 0, "must have some sequences to map onto!"
586 586 pf = ParallelFunction(self, f, block=block, **kwargs)
587 587 return pf.map(*sequences)
588 588
589 589 def execute(self, code, targets=None, block=None):
590 590 """Executes `code` on `targets` in blocking or nonblocking manner.
591 591
592 592 ``execute`` is always `bound` (affects engine namespace)
593 593
594 594 Parameters
595 595 ----------
596 596
597 597 code : str
598 598 the code string to be executed
599 599 block : bool
600 600 whether or not to wait until done to return
601 601 default: self.block
602 602 """
603 603 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
604 604
605 605 def run(self, filename, targets=None, block=None):
606 606 """Execute contents of `filename` on my engine(s).
607 607
608 608 This simply reads the contents of the file and calls `execute`.
609 609
610 610 Parameters
611 611 ----------
612 612
613 613 filename : str
614 614 The path to the file
615 615 targets : int/str/list of ints/strs
616 616 the engines on which to execute
617 617 default : all
618 618 block : bool
619 619 whether or not to wait until done
620 620 default: self.block
621 621
622 622 """
623 623 with open(filename, 'r') as f:
624 624 # add newline in case of trailing indented whitespace
625 625 # which will cause SyntaxError
626 626 code = f.read()+'\n'
627 627 return self.execute(code, block=block, targets=targets)
628 628
629 629 def update(self, ns):
630 630 """update remote namespace with dict `ns`
631 631
632 632 See `push` for details.
633 633 """
634 634 return self.push(ns, block=self.block, track=self.track)
635 635
636 636 def push(self, ns, targets=None, block=None, track=None):
637 637 """update remote namespace with dict `ns`
638 638
639 639 Parameters
640 640 ----------
641 641
642 642 ns : dict
643 643 dict of keys with which to update engine namespace(s)
644 644 block : bool [default : self.block]
645 645 whether to wait to be notified of engine receipt
646 646
647 647 """
648 648
649 649 block = block if block is not None else self.block
650 650 track = track if track is not None else self.track
651 651 targets = targets if targets is not None else self.targets
652 652 # applier = self.apply_sync if block else self.apply_async
653 653 if not isinstance(ns, dict):
654 654 raise TypeError("Must be a dict, not %s"%type(ns))
655 655 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
656 656
657 657 def get(self, key_s):
658 658 """get object(s) by `key_s` from remote namespace
659 659
660 660 see `pull` for details.
661 661 """
662 662 # block = block if block is not None else self.block
663 663 return self.pull(key_s, block=True)
664 664
665 665 def pull(self, names, targets=None, block=None):
666 666 """get object(s) by `name` from remote namespace
667 667
668 668         Will return one object if `names` is a single key.
669 669         Can also take a list of keys, in which case it will return a list of objects.
670 670 """
671 671 block = block if block is not None else self.block
672 672 targets = targets if targets is not None else self.targets
673 673 applier = self.apply_sync if block else self.apply_async
674 674 if isinstance(names, basestring):
675 675 pass
676 676 elif isinstance(names, (list,tuple,set)):
677 677 for key in names:
678 678 if not isinstance(key, basestring):
679 679 raise TypeError("keys must be str, not type %r"%type(key))
680 680 else:
681 681 raise TypeError("names must be strs, not %r"%names)
682 682 return self._really_apply(util._pull, (names,), block=block, targets=targets)
683 683
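A sketch of the push/pull round-trip, again assuming `dv = client[:]`; the dictionary-style access shown in the comments is sugar for the same calls:

    dv.push({'a': 5, 'b': 6}, block=True)   # equivalent: dv['a'] = 5; dv['b'] = 6
    print dv.pull('a', block=True)          # one value per engine: [5, 5, ...]
    print dv.pull(('a', 'b'), block=True)   # per-engine pairs: [[5, 6], [5, 6], ...]
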
684 684 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
685 685 """
686 686 Partition a Python sequence and send the partitions to a set of engines.
687 687 """
688 688 block = block if block is not None else self.block
689 689 track = track if track is not None else self.track
690 690 targets = targets if targets is not None else self.targets
691 691
692 692 mapObject = Map.dists[dist]()
693 693 nparts = len(targets)
694 694 msg_ids = []
695 695 trackers = []
696 696 for index, engineid in enumerate(targets):
697 697 partition = mapObject.getPartition(seq, index, nparts)
698 698 if flatten and len(partition) == 1:
699 699 ns = {key: partition[0]}
700 700 else:
701 701 ns = {key: partition}
702 702 r = self.push(ns, block=False, track=track, targets=engineid)
703 703 msg_ids.extend(r.msg_ids)
704 704 if track:
705 705 trackers.append(r._tracker)
706 706
707 707 if track:
708 708 tracker = zmq.MessageTracker(*trackers)
709 709 else:
710 710 tracker = None
711 711
712 712 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
713 713 if block:
714 714 r.wait()
715 715 else:
716 716 return r
717 717
718 718 @sync_results
719 719 @save_ids
720 720 def gather(self, key, dist='b', targets=None, block=None):
721 721 """
722 722 Gather a partitioned sequence on a set of engines as a single local seq.
723 723 """
724 724 block = block if block is not None else self.block
725 725 targets = targets if targets is not None else self.targets
726 726 mapObject = Map.dists[dist]()
727 727 msg_ids = []
728 728
729 729 for index, engineid in enumerate(targets):
730 730 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
731 731
732 732 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
733 733
734 734 if block:
735 735 try:
736 736 return r.get()
737 737 except KeyboardInterrupt:
738 738 pass
739 739 return r
740 740
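The scatter/gather round-trip, sketched for a hypothetical four-engine `dv = client[:]`:

    dv.scatter('x', range(16))             # each engine receives a contiguous chunk
    dv.execute('y = [i * 2 for i in x]')   # operate on the local partition
    print dv.gather('y', block=True)       # reassembled: [0, 2, 4, ..., 30]
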
741 741 def __getitem__(self, key):
742 742 return self.get(key)
743 743
744 744 def __setitem__(self,key, value):
745 745 self.update({key:value})
746 746
747 747 def clear(self, targets=None, block=False):
748 748 """Clear the remote namespaces on my engines."""
749 749 block = block if block is not None else self.block
750 750 targets = targets if targets is not None else self.targets
751 751 return self.client.clear(targets=targets, block=block)
752 752
753 753 def kill(self, targets=None, block=True):
754 754 """Kill my engines."""
755 755 block = block if block is not None else self.block
756 756 targets = targets if targets is not None else self.targets
757 757 return self.client.kill(targets=targets, block=block)
758 758
759 759 #----------------------------------------
760 760 # activate for %px,%autopx magics
761 761 #----------------------------------------
762 762 def activate(self):
763 763 """Make this `View` active for parallel magic commands.
764 764
765 765 IPython has a magic command syntax to work with `MultiEngineClient` objects.
766 766 In a given IPython session there is a single active one. While
767 767 there can be many `Views` created and used by the user,
768 768 there is only one active one. The active `View` is used whenever
769 769 the magic commands %px and %autopx are used.
770 770
771 771 The activate() method is called on a given `View` to make it
772 772 active. Once this has been done, the magic commands can be used.
773 773 """
774 774
775 775 try:
776 776 # This is injected into __builtins__.
777 777 ip = get_ipython()
778 778 except NameError:
779 779 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
780 780 else:
781 781 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
782 782 if pmagic is None:
783 783 ip.magic_load_ext('parallelmagic')
784 784 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
785 785
786 786 pmagic.active_view = self
787 787
788 788
789 789 @skip_doctest
790 790 class LoadBalancedView(View):
791 791 """An load-balancing View that only executes via the Task scheduler.
792 792
793 793 Load-balanced views can be created with the client's `view` method:
794 794
795 795 >>> v = client.load_balanced_view()
796 796
797 797 or targets can be specified, to restrict the potential destinations:
798 798
799 799     >>> v = client.load_balanced_view([1,3])
800 800
801 801     which would restrict load-balancing to engines 1 and 3.
802 802
803 803 """
804 804
805 805 follow=Any()
806 806 after=Any()
807 807 timeout=CFloat()
808 808 retries = Integer(0)
809 809
810 810 _task_scheme = Any()
811 811 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
812 812
813 813 def __init__(self, client=None, socket=None, **flags):
814 814 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
815 815 self._task_scheme=client._task_scheme
816 816
817 817 def _validate_dependency(self, dep):
818 818 """validate a dependency.
819 819
820 820 For use in `set_flags`.
821 821 """
822 822 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
823 823 return True
824 824 elif isinstance(dep, (list,set, tuple)):
825 825 for d in dep:
826 826 if not isinstance(d, (basestring, AsyncResult)):
827 827 return False
828 828 elif isinstance(dep, dict):
829 829 if set(dep.keys()) != set(Dependency().as_dict().keys()):
830 830 return False
831 831 if not isinstance(dep['msg_ids'], list):
832 832 return False
833 833 for d in dep['msg_ids']:
834 834 if not isinstance(d, basestring):
835 835 return False
836 836 else:
837 837 return False
838 838
839 839 return True
840 840
841 841 def _render_dependency(self, dep):
842 842 """helper for building jsonable dependencies from various input forms."""
843 843 if isinstance(dep, Dependency):
844 844 return dep.as_dict()
845 845 elif isinstance(dep, AsyncResult):
846 846 return dep.msg_ids
847 847 elif dep is None:
848 848 return []
849 849 else:
850 850 # pass to Dependency constructor
851 851 return list(Dependency(dep))
852 852
853 853 def set_flags(self, **kwargs):
854 854 """set my attribute flags by keyword.
855 855
856 856 A View is a wrapper for the Client's apply method, but with attributes
857 857         that specify keyword arguments; those attributes can be set by keyword
858 858 argument with this method.
859 859
860 860 Parameters
861 861 ----------
862 862
863 863 block : bool
864 864 whether to wait for results
865 865 track : bool
866 866 whether to create a MessageTracker to allow the user to
867 867             safely edit arrays and buffers after non-copying
868 868 sends.
869 869
870 870 after : Dependency or collection of msg_ids
871 871 Only for load-balanced execution (targets=None)
872 872 Specify a list of msg_ids as a time-based dependency.
873 873 This job will only be run *after* the dependencies
874 874 have been met.
875 875
876 876 follow : Dependency or collection of msg_ids
877 877 Only for load-balanced execution (targets=None)
878 878 Specify a list of msg_ids as a location-based dependency.
879 879 This job will only be run on an engine where this dependency
880 880 is met.
881 881
882 882 timeout : float/int or None
883 883 Only for load-balanced execution (targets=None)
884 884 Specify an amount of time (in seconds) for the scheduler to
885 885 wait for dependencies to be met before failing with a
886 886 DependencyTimeout.
887 887
888 888 retries : int
889 889 Number of times a task will be retried on failure.
890 890 """
891 891
892 892 super(LoadBalancedView, self).set_flags(**kwargs)
893 893 for name in ('follow', 'after'):
894 894 if name in kwargs:
895 895 value = kwargs[name]
896 896 if self._validate_dependency(value):
897 897 setattr(self, name, value)
898 898 else:
899 899 raise ValueError("Invalid dependency: %r"%value)
900 900 if 'timeout' in kwargs:
901 901 t = kwargs['timeout']
902 902 if not isinstance(t, (int, long, float, type(None))):
903 903 raise TypeError("Invalid type for timeout: %r"%type(t))
904 904 if t is not None:
905 905 if t < 0:
906 906 raise ValueError("Invalid timeout: %s"%t)
907 907 self.timeout = t
908 908
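A sketch of these flags used to chain tasks (assumes `lb = client.load_balanced_view()`; `stage1` and `stage2` are hypothetical functions defined elsewhere):

    ar = lb.apply_async(stage1)
    with lb.temp_flags(after=ar, retries=2, timeout=60):
        ar2 = lb.apply_async(stage2)   # scheduled only once `ar` has completed
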
909 909 @sync_results
910 910 @save_ids
911 911 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
912 912 after=None, follow=None, timeout=None,
913 913 targets=None, retries=None):
914 914 """calls f(*args, **kwargs) on a remote engine, returning the result.
915 915
916 916 This method temporarily sets all of `apply`'s flags for a single call.
917 917
918 918 Parameters
919 919 ----------
920 920
921 921 f : callable
922 922
923 923 args : list [default: empty]
924 924
925 925 kwargs : dict [default: empty]
926 926
927 927 block : bool [default: self.block]
928 928 whether to block
929 929 track : bool [default: self.track]
930 930 whether to ask zmq to track the message, for safe non-copying sends
931 931
932 932         after, follow, timeout, retries : scheduler flags; see `set_flags` for details
933 933
934 934 Returns
935 935 -------
936 936
937 937 if self.block is False:
938 938 returns AsyncResult
939 939 else:
940 940 returns actual result of f(*args, **kwargs) on the engine(s)
941 941             This will be a list if self.targets is also a list (even length 1), or
942 942 the single result if self.targets is an integer engine id
943 943 """
944 944
945 945 # validate whether we can run
946 946 if self._socket.closed:
947 947 msg = "Task farming is disabled"
948 948 if self._task_scheme == 'pure':
949 949 msg += " because the pure ZMQ scheduler cannot handle"
950 950 msg += " disappearing engines."
951 951 raise RuntimeError(msg)
952 952
953 953 if self._task_scheme == 'pure':
954 954 # pure zmq scheme doesn't support extra features
955 955 msg = "Pure ZMQ scheduler doesn't support the following flags:"
956 956 "follow, after, retries, targets, timeout"
957 957 if (follow or after or retries or targets or timeout):
958 958 # hard fail on Scheduler flags
959 959 raise RuntimeError(msg)
960 960 if isinstance(f, dependent):
961 961 # soft warn on functional dependencies
962 962 warnings.warn(msg, RuntimeWarning)
963 963
964 964 # build args
965 965 args = [] if args is None else args
966 966 kwargs = {} if kwargs is None else kwargs
967 967 block = self.block if block is None else block
968 968 track = self.track if track is None else track
969 969 after = self.after if after is None else after
970 970 retries = self.retries if retries is None else retries
971 971 follow = self.follow if follow is None else follow
972 972 timeout = self.timeout if timeout is None else timeout
973 973 targets = self.targets if targets is None else targets
974 974
975 975 if not isinstance(retries, int):
976 976 raise TypeError('retries must be int, not %r'%type(retries))
977 977
978 978 if targets is None:
979 979 idents = []
980 980 else:
981 981 idents = self.client._build_targets(targets)[0]
982 982 # ensure *not* bytes
983 983 idents = [ ident.decode() for ident in idents ]
984 984
985 985 after = self._render_dependency(after)
986 986 follow = self._render_dependency(follow)
987 987 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
988 988
989 989 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
990 990 subheader=subheader)
991 991 tracker = None if track is False else msg['tracker']
992 992
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
994 994
995 995 if block:
996 996 try:
997 997 return ar.get()
998 998 except KeyboardInterrupt:
999 999 pass
1000 1000 return ar
1001 1001
1002 1002 @spin_after
1003 1003 @save_ids
1004 1004 def map(self, f, *sequences, **kwargs):
1005 1005 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1006 1006
1007 1007 Parallel version of builtin `map`, load-balanced by this View.
1008 1008
1009 1009         `block`, `chunksize`, and `ordered` can be specified by keyword only.
1010 1010
1011 1011 Each `chunksize` elements will be a separate task, and will be
1012 1012 load-balanced. This lets individual elements be available for iteration
1013 1013 as soon as they arrive.
1014 1014
1015 1015 Parameters
1016 1016 ----------
1017 1017
1018 1018 f : callable
1019 1019 function to be mapped
1020 1020 *sequences: one or more sequences of matching length
1021 1021 the sequences to be distributed and passed to `f`
1022 1022 block : bool [default self.block]
1023 1023 whether to wait for the result or not
1024 1024 track : bool
1025 1025 whether to create a MessageTracker to allow the user to
1026 1026             safely edit arrays and buffers after non-copying
1027 1027 sends.
1028 1028 chunksize : int [default 1]
1029 1029 how many elements should be in each task.
1030 1030 ordered : bool [default True]
1031 1031 Whether the results should be gathered as they arrive, or enforce
1032 1032 the order of submission.
1033 1033
1034 1034 Only applies when iterating through AsyncMapResult as results arrive.
1035 1035 Has no effect when block=True.
1036 1036
1037 1037 Returns
1038 1038 -------
1039 1039
1040 1040 if block=False:
1041 1041 AsyncMapResult
1042 1042 An object like AsyncResult, but which reassembles the sequence of results
1043 1043 into a single list. AsyncMapResults can be iterated through before all
1044 1044 results are complete.
1045 1045 else:
1046 1046 the result of map(f,*sequences)
1047 1047
1048 1048 """
1049 1049
1050 1050 # default
1051 1051 block = kwargs.get('block', self.block)
1052 1052 chunksize = kwargs.get('chunksize', 1)
1053 1053 ordered = kwargs.get('ordered', True)
1054 1054
1055 1055 keyset = set(kwargs.keys())
1056 1056         extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1057 1057 if extra_keys:
1058 1058 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1059 1059
1060 1060 assert len(sequences) > 0, "must have some sequences to map onto!"
1061 1061
1062 1062 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1063 1063 return pf.map(*sequences)
1064 1064
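A chunked, load-balanced map, sketched with the same hypothetical `lb` as above:

    amr = lb.map(lambda x: x * x, range(100),
                 chunksize=10, block=False, ordered=False)
    for r in amr:     # elements are yielded as each 10-element task finishes
        print r
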
1065 1065 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,472 +1,493 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28
29 29 from IPython import parallel as pmod
30 30 from IPython.parallel import error
31 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 32 from IPython.parallel import DirectView
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 38
39 39 def setup():
40 40 add_engines(3)
41 41
42 42 class TestView(ClusterTestCase):
43 43
44 44 def test_z_crash_mux(self):
45 45 """test graceful handling of engine death (direct)"""
46 46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEquals(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEquals(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEquals(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEquals(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEquals(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEquals(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEquals(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 145 ahr = v2.get_result(ar.msg_ids)
146 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 147 self.assertEquals(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids)
149 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 150 c.spin()
151 151 c.close()
152 152
153 153 def test_run_newline(self):
154 154 """test that run appends newline to files"""
155 155 tmpfile = mktemp()
156 156 with open(tmpfile, 'w') as f:
157 157 f.write("""def g():
158 158 return 5
159 159 """)
160 160 v = self.client[-1]
161 161 v.run(tmpfile, block=True)
162 162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 163
164 164 def test_apply_tracked(self):
165 165 """test tracking for apply"""
166 166 # self.add_engines(1)
167 167 t = self.client.ids[-1]
168 168 v = self.client[t]
169 169 v.block=False
170 170 def echo(n=1024*1024, **kwargs):
171 171 with v.temp_flags(**kwargs):
172 172 return v.apply(lambda x: x, 'x'*n)
173 173 ar = echo(1, track=False)
174 174         self.assertFalse(isinstance(ar._tracker, zmq.MessageTracker))
175 175 self.assertTrue(ar.sent)
176 176 ar = echo(track=True)
177 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 178 self.assertEquals(ar.sent, ar._tracker.done)
179 179 ar._tracker.wait()
180 180 self.assertTrue(ar.sent)
181 181
182 182 def test_push_tracked(self):
183 183 t = self.client.ids[-1]
184 184 ns = dict(x='x'*1024*1024)
185 185 v = self.client[t]
186 186 ar = v.push(ns, block=False, track=False)
187 187         self.assertFalse(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertTrue(ar.sent)
189 189
190 190 ar = v.push(ns, block=False, track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 ar._tracker.wait()
193 193 self.assertEquals(ar.sent, ar._tracker.done)
194 194 self.assertTrue(ar.sent)
195 195 ar.get()
196 196
197 197 def test_scatter_tracked(self):
198 198 t = self.client.ids
199 199 x='x'*1024*1024
200 200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 201         self.assertFalse(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 self.assertEquals(ar.sent, ar._tracker.done)
207 207 ar._tracker.wait()
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_remote_reference(self):
212 212 v = self.client[-1]
213 213 v['a'] = 123
214 214 ra = pmod.Reference('a')
215 215 b = v.apply_sync(lambda x: x, ra)
216 216 self.assertEquals(b, 123)
217 217
218 218
219 219 def test_scatter_gather(self):
220 220 view = self.client[:]
221 221 seq1 = range(16)
222 222 view.scatter('a', seq1)
223 223 seq2 = view.gather('a', block=True)
224 224 self.assertEquals(seq2, seq1)
225 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 226
227 227 @skip_without('numpy')
228 228 def test_scatter_gather_numpy(self):
229 229 import numpy
230 230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 231 view = self.client[:]
232 232 a = numpy.arange(64)
233 233 view.scatter('a', a)
234 234 b = view.gather('a', block=True)
235 235 assert_array_equal(b, a)
236 236
237 237 def test_map(self):
238 238 view = self.client[:]
239 239 def f(x):
240 240 return x**2
241 241 data = range(16)
242 242 r = view.map_sync(f, data)
243 243 self.assertEquals(r, map(f, data))
244 244
245 245 def test_map_iterable(self):
246 246 """test map on iterables (direct)"""
247 247 view = self.client[:]
248 248 # 101 is prime, so it won't be evenly distributed
249 249 arr = range(101)
250 250 # ensure it will be an iterator, even in Python 3
251 251 it = iter(arr)
252 252 r = view.map_sync(lambda x:x, arr)
253 253 self.assertEquals(r, list(arr))
254 254
255 255 def test_scatterGatherNonblocking(self):
256 256 data = range(16)
257 257 view = self.client[:]
258 258 view.scatter('a', data, block=False)
259 259 ar = view.gather('a', block=False)
260 260 self.assertEquals(ar.get(), data)
261 261
262 262 @skip_without('numpy')
263 263 def test_scatter_gather_numpy_nonblocking(self):
264 264 import numpy
265 265 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
266 266 a = numpy.arange(64)
267 267 view = self.client[:]
268 268 ar = view.scatter('a', a, block=False)
269 269 self.assertTrue(isinstance(ar, AsyncResult))
270 270 amr = view.gather('a', block=False)
271 271 self.assertTrue(isinstance(amr, AsyncMapResult))
272 272 assert_array_equal(amr.get(), a)
273 273
274 274 def test_execute(self):
275 275 view = self.client[:]
276 276 # self.client.debug=True
277 277 execute = view.execute
278 278 ar = execute('c=30', block=False)
279 279 self.assertTrue(isinstance(ar, AsyncResult))
280 280 ar = execute('d=[0,1,2]', block=False)
281 281 self.client.wait(ar, 1)
282 282 self.assertEquals(len(ar.get()), len(self.client))
283 283 for c in view['c']:
284 284 self.assertEquals(c, 30)
285 285
286 286 def test_abort(self):
287 287 view = self.client[-1]
288 288 ar = view.execute('import time; time.sleep(1)', block=False)
289 289 ar2 = view.apply_async(lambda : 2)
290 290 ar3 = view.apply_async(lambda : 3)
291 291 view.abort(ar2)
292 292 view.abort(ar3.msg_ids)
293 293 self.assertRaises(error.TaskAborted, ar2.get)
294 294 self.assertRaises(error.TaskAborted, ar3.get)
295 295
296 296 def test_abort_all(self):
297 297 """view.abort() aborts all outstanding tasks"""
298 298 view = self.client[-1]
299 299 ars = [ view.apply_async(time.sleep, 1) for i in range(10) ]
300 300 view.abort()
301 301 view.wait(timeout=5)
302 302 for ar in ars[5:]:
303 303 self.assertRaises(error.TaskAborted, ar.get)
304 304
305 305 def test_temp_flags(self):
306 306 view = self.client[-1]
307 307 view.block=True
308 308 with view.temp_flags(block=False):
309 309 self.assertFalse(view.block)
310 310 self.assertTrue(view.block)
311 311
312 312 @dec.known_failure_py3
313 313 def test_importer(self):
314 314 view = self.client[-1]
315 315 view.clear(block=True)
316 316 with view.importer:
317 317 import re
318 318
319 319 @interactive
320 320 def findall(pat, s):
321 321 # this globals() step isn't necessary in real code
322 322 # only to prevent a closure in the test
323 323 re = globals()['re']
324 324 return re.findall(pat, s)
325 325
326 326 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
327 327
328 328 # parallel magic tests
329 329
330 330 def test_magic_px_blocking(self):
331 331 ip = get_ipython()
332 332 v = self.client[-1]
333 333 v.activate()
334 334 v.block=True
335 335
336 336 ip.magic_px('a=5')
337 337 self.assertEquals(v['a'], 5)
338 338 ip.magic_px('a=10')
339 339 self.assertEquals(v['a'], 10)
340 340 sio = StringIO()
341 341 savestdout = sys.stdout
342 342 sys.stdout = sio
343 343         # just 'print a' works ~99% of the time, but this ensures that
344 344 # the stdout message has arrived when the result is finished:
345 345 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
346 346 sys.stdout = savestdout
347 347 buf = sio.getvalue()
348 348 self.assertTrue('[stdout:' in buf, buf)
349 349 self.assertTrue(buf.rstrip().endswith('10'))
350 350 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
351 351
352 352 def test_magic_px_nonblocking(self):
353 353 ip = get_ipython()
354 354 v = self.client[-1]
355 355 v.activate()
356 356 v.block=False
357 357
358 358 ip.magic_px('a=5')
359 359 self.assertEquals(v['a'], 5)
360 360 ip.magic_px('a=10')
361 361 self.assertEquals(v['a'], 10)
362 362 sio = StringIO()
363 363 savestdout = sys.stdout
364 364 sys.stdout = sio
365 365 ip.magic_px('print a')
366 366 sys.stdout = savestdout
367 367 buf = sio.getvalue()
368 368 self.assertFalse('[stdout:%i]'%v.targets in buf)
369 369 ip.magic_px('1/0')
370 370 ar = v.get_result(-1)
371 371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
372 372
373 373 def test_magic_autopx_blocking(self):
374 374 ip = get_ipython()
375 375 v = self.client[-1]
376 376 v.activate()
377 377 v.block=True
378 378
379 379 sio = StringIO()
380 380 savestdout = sys.stdout
381 381 sys.stdout = sio
382 382 ip.magic_autopx()
383 383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
384 384 ip.run_cell('print b')
385 385 ip.run_cell("b/c")
386 386 ip.run_code(compile('b*=2', '', 'single'))
387 387 ip.magic_autopx()
388 388 sys.stdout = savestdout
389 389 output = sio.getvalue().strip()
390 390 self.assertTrue(output.startswith('%autopx enabled'))
391 391 self.assertTrue(output.endswith('%autopx disabled'))
392 392 self.assertTrue('RemoteError: ZeroDivisionError' in output)
393 393 ar = v.get_result(-2)
394 394 self.assertEquals(v['a'], 5)
395 395 self.assertEquals(v['b'], 20)
396 396 self.assertRaisesRemote(ZeroDivisionError, ar.get)
397 397
398 398 def test_magic_autopx_nonblocking(self):
399 399 ip = get_ipython()
400 400 v = self.client[-1]
401 401 v.activate()
402 402 v.block=False
403 403
404 404 sio = StringIO()
405 405 savestdout = sys.stdout
406 406 sys.stdout = sio
407 407 ip.magic_autopx()
408 408 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
409 409 ip.run_cell('print b')
410 410 ip.run_cell("b/c")
411 411 ip.run_code(compile('b*=2', '', 'single'))
412 412 ip.magic_autopx()
413 413 sys.stdout = savestdout
414 414 output = sio.getvalue().strip()
415 415 self.assertTrue(output.startswith('%autopx enabled'))
416 416 self.assertTrue(output.endswith('%autopx disabled'))
417 417 self.assertFalse('ZeroDivisionError' in output)
418 418 ar = v.get_result(-2)
419 419 self.assertEquals(v['a'], 5)
420 420 self.assertEquals(v['b'], 20)
421 421 self.assertRaisesRemote(ZeroDivisionError, ar.get)
422 422
423 423 def test_magic_result(self):
424 424 ip = get_ipython()
425 425 v = self.client[-1]
426 426 v.activate()
427 427 v['a'] = 111
428 428 ra = v['a']
429 429
430 430 ar = ip.magic_result()
431 431 self.assertEquals(ar.msg_ids, [v.history[-1]])
432 432 self.assertEquals(ar.get(), 111)
433 433 ar = ip.magic_result('-2')
434 434 self.assertEquals(ar.msg_ids, [v.history[-2]])
435 435
436 436 def test_unicode_execute(self):
437 437 """test executing unicode strings"""
438 438 v = self.client[-1]
439 439 v.block=True
440 440 if sys.version_info[0] >= 3:
441 441 code="a='é'"
442 442 else:
443 443 code=u"a=u'é'"
444 444 v.execute(code)
445 445 self.assertEquals(v['a'], u'é')
446 446
447 447 def test_unicode_apply_result(self):
448 448 """test unicode apply results"""
449 449 v = self.client[-1]
450 450 r = v.apply_sync(lambda : u'é')
451 451 self.assertEquals(r, u'é')
452 452
453 453 def test_unicode_apply_arg(self):
454 454 """test passing unicode arguments to apply"""
455 455 v = self.client[-1]
456 456
457 457 @interactive
458 458 def check_unicode(a, check):
459 459 assert isinstance(a, unicode), "%r is not unicode"%a
460 460 assert isinstance(check, bytes), "%r is not bytes"%check
461 461 assert a.encode('utf8') == check, "%s != %s"%(a,check)
462 462
463 463 for s in [ u'é', u'ßø®∫',u'asdf' ]:
464 464 try:
465 465 v.apply_sync(check_unicode, s, s.encode('utf8'))
466 466 except error.RemoteError as e:
467 467 if e.ename == 'AssertionError':
468 468 self.fail(e.evalue)
469 469 else:
470 470 raise e
471 471
472 def test_map_reference(self):
473 """view.map(<Reference>, *seqs) should work"""
474 v = self.client[:]
475 v.scatter('n', self.client.ids, flatten=True)
476 v.execute("f = lambda x,y: x*y")
477 rf = pmod.Reference('f')
478 nlist = list(range(10))
479 mlist = nlist[::-1]
480 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 result = v.map_sync(rf, mlist, nlist)
482 self.assertEquals(result, expected)
483
484 def test_apply_reference(self):
485 """view.apply(<Reference>, *args) should work"""
486 v = self.client[:]
487 v.scatter('n', self.client.ids, flatten=True)
488 v.execute("f = lambda x: n*x")
489 rf = pmod.Reference('f')
490 result = v.apply_sync(rf, 5)
491 expected = [ 5*id for id in self.client.ids ]
492 self.assertEquals(result, expected)
472 493