Changed the documentation so it correctly describes `quiet` as silencing output about the attempt at an import, not about the success of the import.
Ben Edwards -
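For reference, a minimal sketch of the behaviour the updated docstring describes (illustrative only, not part of the change; it assumes a controller and engines are already running and that numpy is available on them):

    from IPython.parallel import Client

    rc = Client()   # connect to a running cluster (assumed)
    dv = rc[:]      # DirectView on all engines

    # Default behaviour: each import triggers a message such as
    # "importing numpy on engine(s)" when the remote attempt is issued;
    # the message does not mean the remote import succeeded.
    with dv.sync_imports():
        import numpy

    # quiet=True only suppresses those messages; remote ImportErrors are
    # still raised when the block exits.
    with dv.sync_imports(quiet=True):
        import numpy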
@@ -1,1069 +1,1069 @@
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89 control methods spin, and requesting `ids` also ensures the view is up to date
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130 130
131 131
132 132 def __repr__(self):
133 133 strtargets = str(self.targets)
134 134 if len(strtargets) > 16:
135 135 strtargets = strtargets[:12]+'...]'
136 136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137 137
138 138 def set_flags(self, **kwargs):
139 139 """set my attribute flags by keyword.
140 140
141 141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 142 These attributes can be set all at once by name with this method.
143 143
144 144 Parameters
145 145 ----------
146 146
147 147 block : bool
148 148 whether to wait for results
149 149 track : bool
150 150 whether to create a MessageTracker to allow the user to
151 151 safely edit arrays and buffers after non-copying
152 152 sends.
153 153 """
154 154 for name, value in kwargs.iteritems():
155 155 if name not in self._flag_names:
156 156 raise KeyError("Invalid name: %r"%name)
157 157 else:
158 158 setattr(self, name, value)
159 159
160 160 @contextmanager
161 161 def temp_flags(self, **kwargs):
162 162 """temporarily set flags, for use in `with` statements.
163 163
164 164 See set_flags for permanent setting of flags
165 165
166 166 Examples
167 167 --------
168 168
169 169 >>> view.track=False
170 170 ...
171 171 >>> with view.temp_flags(track=True):
172 172 ... ar = view.apply(dostuff, my_big_array)
173 173 ... ar.tracker.wait() # wait for send to finish
174 174 >>> view.track
175 175 False
176 176
177 177 """
178 178 # preflight: save flags, and set temporaries
179 179 saved_flags = {}
180 180 for f in self._flag_names:
181 181 saved_flags[f] = getattr(self, f)
182 182 self.set_flags(**kwargs)
183 183 # yield to the with-statement block
184 184 try:
185 185 yield
186 186 finally:
187 187 # postflight: restore saved flags
188 188 self.set_flags(**saved_flags)
189 189
190 190
191 191 #----------------------------------------------------------------
192 192 # apply
193 193 #----------------------------------------------------------------
194 194
195 195 @sync_results
196 196 @save_ids
197 197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 198 """wrapper for client.send_apply_message"""
199 199 raise NotImplementedError("Implement in subclasses")
200 200
201 201 def apply(self, f, *args, **kwargs):
202 202 """calls f(*args, **kwargs) on remote engines, returning the result.
203 203
204 204 This method sets all apply flags via this View's attributes.
205 205
206 206 if self.block is False:
207 207 returns AsyncResult
208 208 else:
209 209 returns actual result of f(*args, **kwargs)
210 210 """
211 211 return self._really_apply(f, args, kwargs)
212 212
213 213 def apply_async(self, f, *args, **kwargs):
214 214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215 215
216 216 returns AsyncResult
217 217 """
218 218 return self._really_apply(f, args, kwargs, block=False)
219 219
220 220 @spin_after
221 221 def apply_sync(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 223 returning the result.
224 224
225 225 returns: actual result of f(*args, **kwargs)
226 226 """
227 227 return self._really_apply(f, args, kwargs, block=True)
228 228
229 229 #----------------------------------------------------------------
230 230 # wrappers for client and control methods
231 231 #----------------------------------------------------------------
232 232 @sync_results
233 233 def spin(self):
234 234 """spin the client, and sync"""
235 235 self.client.spin()
236 236
237 237 @sync_results
238 238 def wait(self, jobs=None, timeout=-1):
239 239 """waits on one or more `jobs`, for up to `timeout` seconds.
240 240
241 241 Parameters
242 242 ----------
243 243
244 244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 245 ints are indices to self.history
246 246 strs are msg_ids
247 247 default: wait on all outstanding messages
248 248 timeout : float
249 249 a time in seconds, after which to give up.
250 250 default is -1, which means no timeout
251 251
252 252 Returns
253 253 -------
254 254
255 255 True : when all msg_ids are done
256 256 False : timeout reached, some msg_ids still outstanding
257 257 """
258 258 if jobs is None:
259 259 jobs = self.history
260 260 return self.client.wait(jobs, timeout)
261 261
262 262 def abort(self, jobs=None, targets=None, block=None):
263 263 """Abort jobs on my engines.
264 264
265 265 Parameters
266 266 ----------
267 267
268 268 jobs : None, str, list of strs, optional
269 269 if None: abort all jobs.
270 270 else: abort specific msg_id(s).
271 271 """
272 272 block = block if block is not None else self.block
273 273 targets = targets if targets is not None else self.targets
274 274 jobs = jobs if jobs is not None else list(self.outstanding)
275 275
276 276 return self.client.abort(jobs=jobs, targets=targets, block=block)
277 277
278 278 def queue_status(self, targets=None, verbose=False):
279 279 """Fetch the Queue status of my engines"""
280 280 targets = targets if targets is not None else self.targets
281 281 return self.client.queue_status(targets=targets, verbose=verbose)
282 282
283 283 def purge_results(self, jobs=[], targets=[]):
284 284 """Instruct the controller to forget specific results."""
285 285 if targets is None or targets == 'all':
286 286 targets = self.targets
287 287 return self.client.purge_results(jobs=jobs, targets=targets)
288 288
289 289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 290 """Terminates one or more engine processes, optionally including the hub.
291 291 """
292 292 block = self.block if block is None else block
293 293 if targets is None or targets == 'all':
294 294 targets = self.targets
295 295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296 296
297 297 @spin_after
298 298 def get_result(self, indices_or_msg_ids=None):
299 299 """return one or more results, specified by history index or msg_id.
300 300
301 301 See client.get_result for details.
302 302
303 303 """
304 304
305 305 if indices_or_msg_ids is None:
306 306 indices_or_msg_ids = -1
307 307 if isinstance(indices_or_msg_ids, int):
308 308 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 310 indices_or_msg_ids = list(indices_or_msg_ids)
311 311 for i,index in enumerate(indices_or_msg_ids):
312 312 if isinstance(index, int):
313 313 indices_or_msg_ids[i] = self.history[index]
314 314 return self.client.get_result(indices_or_msg_ids)
315 315
316 316 #-------------------------------------------------------------------
317 317 # Map
318 318 #-------------------------------------------------------------------
319 319
320 320 def map(self, f, *sequences, **kwargs):
321 321 """override in subclasses"""
322 322 raise NotImplementedError
323 323
324 324 def map_async(self, f, *sequences, **kwargs):
325 325 """Parallel version of builtin `map`, using this view's engines.
326 326
327 327 This is equivalent to map(...block=False)
328 328
329 329 See `self.map` for details.
330 330 """
331 331 if 'block' in kwargs:
332 332 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 333 kwargs['block'] = False
334 334 return self.map(f,*sequences,**kwargs)
335 335
336 336 def map_sync(self, f, *sequences, **kwargs):
337 337 """Parallel version of builtin `map`, using this view's engines.
338 338
339 339 This is equivalent to map(...block=True)
340 340
341 341 See `self.map` for details.
342 342 """
343 343 if 'block' in kwargs:
344 344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 345 kwargs['block'] = True
346 346 return self.map(f,*sequences,**kwargs)
347 347
348 348 def imap(self, f, *sequences, **kwargs):
349 349 """Parallel version of `itertools.imap`.
350 350
351 351 See `self.map` for details.
352 352
353 353 """
354 354
355 355 return iter(self.map_async(f,*sequences, **kwargs))
356 356
357 357 #-------------------------------------------------------------------
358 358 # Decorators
359 359 #-------------------------------------------------------------------
360 360
361 361 def remote(self, block=True, **flags):
362 362 """Decorator for making a RemoteFunction"""
363 363 block = self.block if block is None else block
364 364 return remote(self, block=block, **flags)
365 365
366 366 def parallel(self, dist='b', block=None, **flags):
367 367 """Decorator for making a ParallelFunction"""
368 368 block = self.block if block is None else block
369 369 return parallel(self, dist=dist, block=block, **flags)
370 370
371 371 @skip_doctest
372 372 class DirectView(View):
373 373 """Direct Multiplexer View of one or more engines.
374 374
375 375 These are created via indexed access to a client:
376 376
377 377 >>> dv_1 = client[1]
378 378 >>> dv_all = client[:]
379 379 >>> dv_even = client[::2]
380 380 >>> dv_some = client[1:3]
381 381
382 382 This object provides dictionary access to engine namespaces:
383 383
384 384 # push a=5:
385 385 >>> dv['a'] = 5
386 386 # pull 'foo':
387 387 >>> dv['foo']
388 388
389 389 """
390 390
391 391 def __init__(self, client=None, socket=None, targets=None):
392 392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393 393
394 394 @property
395 395 def importer(self):
396 396 """sync_imports(local=True) as a property.
397 397
398 398 See sync_imports for details.
399 399
400 400 """
401 401 return self.sync_imports(True)
402 402
403 403 @contextmanager
404 404 def sync_imports(self, local=True, quiet=False):
405 405 """Context Manager for performing simultaneous local and remote imports.
406 406
407 407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408 408
409 409 If `local=True`, then the package will also be imported locally.
410 410
411 If `quiet=True`, then no message concerning the success of import will be
412 reported.
411 If `quiet=True`, no output will be produced when attempting remote
412 imports.
413 413
414 414 Note that remote-only (`local=False`) imports have not been implemented.
415 415
416 416 >>> with view.sync_imports():
417 417 ... from numpy import recarray
418 418 importing recarray from numpy on engine(s)
419 419
420 420 """
421 421 import __builtin__
422 422 local_import = __builtin__.__import__
423 423 modules = set()
424 424 results = []
425 425 @util.interactive
426 426 def remote_import(name, fromlist, level):
427 427 """the function to be passed to apply, that actually performs the import
428 428 on the engine, and loads up the user namespace.
429 429 """
430 430 import sys
431 431 user_ns = globals()
432 432 mod = __import__(name, fromlist=fromlist, level=level)
433 433 if fromlist:
434 434 for key in fromlist:
435 435 user_ns[key] = getattr(mod, key)
436 436 else:
437 437 user_ns[name] = sys.modules[name]
438 438
439 439 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
440 440 """the drop-in replacement for __import__, that optionally imports
441 441 locally as well.
442 442 """
443 443 # don't override nested imports
444 444 save_import = __builtin__.__import__
445 445 __builtin__.__import__ = local_import
446 446
447 447 if imp.lock_held():
448 448 # this is a side-effect import, don't do it remotely, or even
449 449 # ignore the local effects
450 450 return local_import(name, globals, locals, fromlist, level)
451 451
452 452 imp.acquire_lock()
453 453 if local:
454 454 mod = local_import(name, globals, locals, fromlist, level)
455 455 else:
456 456 raise NotImplementedError("remote-only imports not yet implemented")
457 457 imp.release_lock()
458 458
459 459 key = name+':'+','.join(fromlist or [])
460 460 if level == -1 and key not in modules:
461 461 modules.add(key)
462 462 if not quiet:
463 463 if fromlist:
464 464 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
465 465 else:
466 466 print "importing %s on engine(s)"%name
467 467 results.append(self.apply_async(remote_import, name, fromlist, level))
468 468 # restore override
469 469 __builtin__.__import__ = save_import
470 470
471 471 return mod
472 472
473 473 # override __import__
474 474 __builtin__.__import__ = view_import
475 475 try:
476 476 # enter the block
477 477 yield
478 478 except ImportError:
479 479 if local:
480 480 raise
481 481 else:
482 482 # ignore import errors if not doing local imports
483 483 pass
484 484 finally:
485 485 # always restore __import__
486 486 __builtin__.__import__ = local_import
487 487
488 488 for r in results:
489 489 # raise possible remote ImportErrors here
490 490 r.get()
491 491
492 492
493 493 @sync_results
494 494 @save_ids
495 495 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
496 496 """calls f(*args, **kwargs) on remote engines, returning the result.
497 497
498 498 This method sets all of `apply`'s flags via this View's attributes.
499 499
500 500 Parameters
501 501 ----------
502 502
503 503 f : callable
504 504
505 505 args : list [default: empty]
506 506
507 507 kwargs : dict [default: empty]
508 508
509 509 targets : target list [default: self.targets]
510 510 where to run
511 511 block : bool [default: self.block]
512 512 whether to block
513 513 track : bool [default: self.track]
514 514 whether to ask zmq to track the message, for safe non-copying sends
515 515
516 516 Returns
517 517 -------
518 518
519 519 if self.block is False:
520 520 returns AsyncResult
521 521 else:
522 522 returns actual result of f(*args, **kwargs) on the engine(s)
523 523 This will be a list if self.targets is also a list (even length 1), or
524 524 the single result if self.targets is an integer engine id
525 525 """
526 526 args = [] if args is None else args
527 527 kwargs = {} if kwargs is None else kwargs
528 528 block = self.block if block is None else block
529 529 track = self.track if track is None else track
530 530 targets = self.targets if targets is None else targets
531 531
532 532 _idents = self.client._build_targets(targets)[0]
533 533 msg_ids = []
534 534 trackers = []
535 535 for ident in _idents:
536 536 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
537 537 ident=ident)
538 538 if track:
539 539 trackers.append(msg['tracker'])
540 540 msg_ids.append(msg['header']['msg_id'])
541 541 tracker = None if track is False else zmq.MessageTracker(*trackers)
542 542 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
543 543 if block:
544 544 try:
545 545 return ar.get()
546 546 except KeyboardInterrupt:
547 547 pass
548 548 return ar
549 549
550 550 @spin_after
551 551 def map(self, f, *sequences, **kwargs):
552 552 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
553 553
554 554 Parallel version of builtin `map`, using this View's `targets`.
555 555
556 556 There will be one task per target, so work will be chunked
557 557 if the sequences are longer than `targets`.
558 558
559 559 Results can be iterated as they are ready, but will become available in chunks.
560 560
561 561 Parameters
562 562 ----------
563 563
564 564 f : callable
565 565 function to be mapped
566 566 *sequences: one or more sequences of matching length
567 567 the sequences to be distributed and passed to `f`
568 568 block : bool
569 569 whether to wait for the result or not [default self.block]
570 570
571 571 Returns
572 572 -------
573 573
574 574 if block=False:
575 575 AsyncMapResult
576 576 An object like AsyncResult, but which reassembles the sequence of results
577 577 into a single list. AsyncMapResults can be iterated through before all
578 578 results are complete.
579 579 else:
580 580 list
581 581 the result of map(f,*sequences)
582 582 """
583 583
584 584 block = kwargs.pop('block', self.block)
585 585 for k in kwargs.keys():
586 586 if k not in ['block', 'track']:
587 587 raise TypeError("invalid keyword arg, %r"%k)
588 588
589 589 assert len(sequences) > 0, "must have some sequences to map onto!"
590 590 pf = ParallelFunction(self, f, block=block, **kwargs)
591 591 return pf.map(*sequences)
592 592
593 593 def execute(self, code, targets=None, block=None):
594 594 """Executes `code` on `targets` in blocking or nonblocking manner.
595 595
596 596 ``execute`` is always `bound` (affects engine namespace)
597 597
598 598 Parameters
599 599 ----------
600 600
601 601 code : str
602 602 the code string to be executed
603 603 block : bool
604 604 whether or not to wait until done to return
605 605 default: self.block
606 606 """
607 607 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
608 608
609 609 def run(self, filename, targets=None, block=None):
610 610 """Execute contents of `filename` on my engine(s).
611 611
612 612 This simply reads the contents of the file and calls `execute`.
613 613
614 614 Parameters
615 615 ----------
616 616
617 617 filename : str
618 618 The path to the file
619 619 targets : int/str/list of ints/strs
620 620 the engines on which to execute
621 621 default : all
622 622 block : bool
623 623 whether or not to wait until done
624 624 default: self.block
625 625
626 626 """
627 627 with open(filename, 'r') as f:
628 628 # add newline in case of trailing indented whitespace
629 629 # which will cause SyntaxError
630 630 code = f.read()+'\n'
631 631 return self.execute(code, block=block, targets=targets)
632 632
633 633 def update(self, ns):
634 634 """update remote namespace with dict `ns`
635 635
636 636 See `push` for details.
637 637 """
638 638 return self.push(ns, block=self.block, track=self.track)
639 639
640 640 def push(self, ns, targets=None, block=None, track=None):
641 641 """update remote namespace with dict `ns`
642 642
643 643 Parameters
644 644 ----------
645 645
646 646 ns : dict
647 647 dict of keys with which to update engine namespace(s)
648 648 block : bool [default : self.block]
649 649 whether to wait to be notified of engine receipt
650 650
651 651 """
652 652
653 653 block = block if block is not None else self.block
654 654 track = track if track is not None else self.track
655 655 targets = targets if targets is not None else self.targets
656 656 # applier = self.apply_sync if block else self.apply_async
657 657 if not isinstance(ns, dict):
658 658 raise TypeError("Must be a dict, not %s"%type(ns))
659 659 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
660 660
661 661 def get(self, key_s):
662 662 """get object(s) by `key_s` from remote namespace
663 663
664 664 see `pull` for details.
665 665 """
666 666 # block = block if block is not None else self.block
667 667 return self.pull(key_s, block=True)
668 668
669 669 def pull(self, names, targets=None, block=None):
670 670 """get object(s) by `name` from remote namespace
671 671
672 672 will return one object if it is a key.
673 673 can also take a list of keys, in which case it will return a list of objects.
674 674 """
675 675 block = block if block is not None else self.block
676 676 targets = targets if targets is not None else self.targets
677 677 applier = self.apply_sync if block else self.apply_async
678 678 if isinstance(names, basestring):
679 679 pass
680 680 elif isinstance(names, (list,tuple,set)):
681 681 for key in names:
682 682 if not isinstance(key, basestring):
683 683 raise TypeError("keys must be str, not type %r"%type(key))
684 684 else:
685 685 raise TypeError("names must be strs, not %r"%names)
686 686 return self._really_apply(util._pull, (names,), block=block, targets=targets)
687 687
688 688 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
689 689 """
690 690 Partition a Python sequence and send the partitions to a set of engines.
691 691 """
692 692 block = block if block is not None else self.block
693 693 track = track if track is not None else self.track
694 694 targets = targets if targets is not None else self.targets
695 695
696 696 mapObject = Map.dists[dist]()
697 697 nparts = len(targets)
698 698 msg_ids = []
699 699 trackers = []
700 700 for index, engineid in enumerate(targets):
701 701 partition = mapObject.getPartition(seq, index, nparts)
702 702 if flatten and len(partition) == 1:
703 703 ns = {key: partition[0]}
704 704 else:
705 705 ns = {key: partition}
706 706 r = self.push(ns, block=False, track=track, targets=engineid)
707 707 msg_ids.extend(r.msg_ids)
708 708 if track:
709 709 trackers.append(r._tracker)
710 710
711 711 if track:
712 712 tracker = zmq.MessageTracker(*trackers)
713 713 else:
714 714 tracker = None
715 715
716 716 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
717 717 if block:
718 718 r.wait()
719 719 else:
720 720 return r
721 721
722 722 @sync_results
723 723 @save_ids
724 724 def gather(self, key, dist='b', targets=None, block=None):
725 725 """
726 726 Gather a partitioned sequence on a set of engines as a single local seq.
727 727 """
728 728 block = block if block is not None else self.block
729 729 targets = targets if targets is not None else self.targets
730 730 mapObject = Map.dists[dist]()
731 731 msg_ids = []
732 732
733 733 for index, engineid in enumerate(targets):
734 734 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
735 735
736 736 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
737 737
738 738 if block:
739 739 try:
740 740 return r.get()
741 741 except KeyboardInterrupt:
742 742 pass
743 743 return r
744 744
745 745 def __getitem__(self, key):
746 746 return self.get(key)
747 747
748 748 def __setitem__(self,key, value):
749 749 self.update({key:value})
750 750
751 751 def clear(self, targets=None, block=False):
752 752 """Clear the remote namespaces on my engines."""
753 753 block = block if block is not None else self.block
754 754 targets = targets if targets is not None else self.targets
755 755 return self.client.clear(targets=targets, block=block)
756 756
757 757 def kill(self, targets=None, block=True):
758 758 """Kill my engines."""
759 759 block = block if block is not None else self.block
760 760 targets = targets if targets is not None else self.targets
761 761 return self.client.kill(targets=targets, block=block)
762 762
763 763 #----------------------------------------
764 764 # activate for %px,%autopx magics
765 765 #----------------------------------------
766 766 def activate(self):
767 767 """Make this `View` active for parallel magic commands.
768 768
769 769 IPython has a magic command syntax to work with `MultiEngineClient` objects.
770 770 In a given IPython session there is a single active one. While
771 771 there can be many `Views` created and used by the user,
772 772 there is only one active one. The active `View` is used whenever
773 773 the magic commands %px and %autopx are used.
774 774
775 775 The activate() method is called on a given `View` to make it
776 776 active. Once this has been done, the magic commands can be used.
777 777 """
778 778
779 779 try:
780 780 # This is injected into __builtins__.
781 781 ip = get_ipython()
782 782 except NameError:
783 783 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
784 784 else:
785 785 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
786 786 if pmagic is None:
787 787 ip.magic_load_ext('parallelmagic')
788 788 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
789 789
790 790 pmagic.active_view = self
791 791
792 792
793 793 @skip_doctest
794 794 class LoadBalancedView(View):
795 795 """A load-balancing View that only executes via the Task scheduler.
796 796
797 797 Load-balanced views can be created with the client's `view` method:
798 798
799 799 >>> v = client.load_balanced_view()
800 800
801 801 or targets can be specified, to restrict the potential destinations:
802 802
803 803 >>> v = client.load_balanced_view([1,3])
804 804
805 805 which would restrict load-balancing to engines 1 and 3.
806 806
807 807 """
808 808
809 809 follow=Any()
810 810 after=Any()
811 811 timeout=CFloat()
812 812 retries = Integer(0)
813 813
814 814 _task_scheme = Any()
815 815 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
816 816
817 817 def __init__(self, client=None, socket=None, **flags):
818 818 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
819 819 self._task_scheme=client._task_scheme
820 820
821 821 def _validate_dependency(self, dep):
822 822 """validate a dependency.
823 823
824 824 For use in `set_flags`.
825 825 """
826 826 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
827 827 return True
828 828 elif isinstance(dep, (list,set, tuple)):
829 829 for d in dep:
830 830 if not isinstance(d, (basestring, AsyncResult)):
831 831 return False
832 832 elif isinstance(dep, dict):
833 833 if set(dep.keys()) != set(Dependency().as_dict().keys()):
834 834 return False
835 835 if not isinstance(dep['msg_ids'], list):
836 836 return False
837 837 for d in dep['msg_ids']:
838 838 if not isinstance(d, basestring):
839 839 return False
840 840 else:
841 841 return False
842 842
843 843 return True
844 844
845 845 def _render_dependency(self, dep):
846 846 """helper for building jsonable dependencies from various input forms."""
847 847 if isinstance(dep, Dependency):
848 848 return dep.as_dict()
849 849 elif isinstance(dep, AsyncResult):
850 850 return dep.msg_ids
851 851 elif dep is None:
852 852 return []
853 853 else:
854 854 # pass to Dependency constructor
855 855 return list(Dependency(dep))
856 856
857 857 def set_flags(self, **kwargs):
858 858 """set my attribute flags by keyword.
859 859
860 860 A View is a wrapper for the Client's apply method, but with attributes
861 861 that specify keyword arguments, those attributes can be set by keyword
862 862 argument with this method.
863 863
864 864 Parameters
865 865 ----------
866 866
867 867 block : bool
868 868 whether to wait for results
869 869 track : bool
870 870 whether to create a MessageTracker to allow the user to
871 871 safely edit arrays and buffers after non-copying
872 872 sends.
873 873
874 874 after : Dependency or collection of msg_ids
875 875 Only for load-balanced execution (targets=None)
876 876 Specify a list of msg_ids as a time-based dependency.
877 877 This job will only be run *after* the dependencies
878 878 have been met.
879 879
880 880 follow : Dependency or collection of msg_ids
881 881 Only for load-balanced execution (targets=None)
882 882 Specify a list of msg_ids as a location-based dependency.
883 883 This job will only be run on an engine where this dependency
884 884 is met.
885 885
886 886 timeout : float/int or None
887 887 Only for load-balanced execution (targets=None)
888 888 Specify an amount of time (in seconds) for the scheduler to
889 889 wait for dependencies to be met before failing with a
890 890 DependencyTimeout.
891 891
892 892 retries : int
893 893 Number of times a task will be retried on failure.
894 894 """
895 895
896 896 super(LoadBalancedView, self).set_flags(**kwargs)
897 897 for name in ('follow', 'after'):
898 898 if name in kwargs:
899 899 value = kwargs[name]
900 900 if self._validate_dependency(value):
901 901 setattr(self, name, value)
902 902 else:
903 903 raise ValueError("Invalid dependency: %r"%value)
904 904 if 'timeout' in kwargs:
905 905 t = kwargs['timeout']
906 906 if not isinstance(t, (int, long, float, type(None))):
907 907 raise TypeError("Invalid type for timeout: %r"%type(t))
908 908 if t is not None:
909 909 if t < 0:
910 910 raise ValueError("Invalid timeout: %s"%t)
911 911 self.timeout = t
912 912
913 913 @sync_results
914 914 @save_ids
915 915 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
916 916 after=None, follow=None, timeout=None,
917 917 targets=None, retries=None):
918 918 """calls f(*args, **kwargs) on a remote engine, returning the result.
919 919
920 920 This method temporarily sets all of `apply`'s flags for a single call.
921 921
922 922 Parameters
923 923 ----------
924 924
925 925 f : callable
926 926
927 927 args : list [default: empty]
928 928
929 929 kwargs : dict [default: empty]
930 930
931 931 block : bool [default: self.block]
932 932 whether to block
933 933 track : bool [default: self.track]
934 934 whether to ask zmq to track the message, for safe non-copying sends
935 935
936 936 !!!!!! TODO: THE REST HERE !!!!
937 937
938 938 Returns
939 939 -------
940 940
941 941 if self.block is False:
942 942 returns AsyncResult
943 943 else:
944 944 returns actual result of f(*args, **kwargs) on the engine(s)
945 945 This will be a list if self.targets is also a list (even length 1), or
946 946 the single result if self.targets is an integer engine id
947 947 """
948 948
949 949 # validate whether we can run
950 950 if self._socket.closed:
951 951 msg = "Task farming is disabled"
952 952 if self._task_scheme == 'pure':
953 953 msg += " because the pure ZMQ scheduler cannot handle"
954 954 msg += " disappearing engines."
955 955 raise RuntimeError(msg)
956 956
957 957 if self._task_scheme == 'pure':
958 958 # pure zmq scheme doesn't support extra features
959 959 msg = ("Pure ZMQ scheduler doesn't support the following flags: "
960 960 "follow, after, retries, targets, timeout")
961 961 if (follow or after or retries or targets or timeout):
962 962 # hard fail on Scheduler flags
963 963 raise RuntimeError(msg)
964 964 if isinstance(f, dependent):
965 965 # soft warn on functional dependencies
966 966 warnings.warn(msg, RuntimeWarning)
967 967
968 968 # build args
969 969 args = [] if args is None else args
970 970 kwargs = {} if kwargs is None else kwargs
971 971 block = self.block if block is None else block
972 972 track = self.track if track is None else track
973 973 after = self.after if after is None else after
974 974 retries = self.retries if retries is None else retries
975 975 follow = self.follow if follow is None else follow
976 976 timeout = self.timeout if timeout is None else timeout
977 977 targets = self.targets if targets is None else targets
978 978
979 979 if not isinstance(retries, int):
980 980 raise TypeError('retries must be int, not %r'%type(retries))
981 981
982 982 if targets is None:
983 983 idents = []
984 984 else:
985 985 idents = self.client._build_targets(targets)[0]
986 986 # ensure *not* bytes
987 987 idents = [ ident.decode() for ident in idents ]
988 988
989 989 after = self._render_dependency(after)
990 990 follow = self._render_dependency(follow)
991 991 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
992 992
993 993 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
994 994 subheader=subheader)
995 995 tracker = None if track is False else msg['tracker']
996 996
997 997 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
998 998
999 999 if block:
1000 1000 try:
1001 1001 return ar.get()
1002 1002 except KeyboardInterrupt:
1003 1003 pass
1004 1004 return ar
1005 1005
1006 1006 @spin_after
1007 1007 @save_ids
1008 1008 def map(self, f, *sequences, **kwargs):
1009 1009 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1010 1010
1011 1011 Parallel version of builtin `map`, load-balanced by this View.
1012 1012
1013 1013 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1014 1014
1015 1015 Each `chunksize` elements will be a separate task, and will be
1016 1016 load-balanced. This lets individual elements be available for iteration
1017 1017 as soon as they arrive.
1018 1018
1019 1019 Parameters
1020 1020 ----------
1021 1021
1022 1022 f : callable
1023 1023 function to be mapped
1024 1024 *sequences: one or more sequences of matching length
1025 1025 the sequences to be distributed and passed to `f`
1026 1026 block : bool [default self.block]
1027 1027 whether to wait for the result or not
1028 1028 track : bool
1029 1029 whether to create a MessageTracker to allow the user to
1030 1030 safely edit arrays and buffers after non-copying
1031 1031 sends.
1032 1032 chunksize : int [default 1]
1033 1033 how many elements should be in each task.
1034 1034 ordered : bool [default True]
1035 1035 Whether the results should be gathered as they arrive, or enforce
1036 1036 the order of submission.
1037 1037
1038 1038 Only applies when iterating through AsyncMapResult as results arrive.
1039 1039 Has no effect when block=True.
1040 1040
1041 1041 Returns
1042 1042 -------
1043 1043
1044 1044 if block=False:
1045 1045 AsyncMapResult
1046 1046 An object like AsyncResult, but which reassembles the sequence of results
1047 1047 into a single list. AsyncMapResults can be iterated through before all
1048 1048 results are complete.
1049 1049 else:
1050 1050 the result of map(f,*sequences)
1051 1051
1052 1052 """
1053 1053
1054 1054 # default
1055 1055 block = kwargs.get('block', self.block)
1056 1056 chunksize = kwargs.get('chunksize', 1)
1057 1057 ordered = kwargs.get('ordered', True)
1058 1058
1059 1059 keyset = set(kwargs.keys())
1060 1060 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1061 1061 if extra_keys:
1062 1062 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1063 1063
1064 1064 assert len(sequences) > 0, "must have some sequences to map onto!"
1065 1065
1066 1066 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1067 1067 return pf.map(*sequences)
1068 1068
1069 1069 __all__ = ['LoadBalancedView', 'DirectView']
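As a quick orientation to the two exported classes, here is an illustrative usage sketch (it assumes a controller and engines are already running; the data and functions are placeholders):

    from IPython.parallel import Client

    rc = Client()

    # DirectView: explicit targets and dictionary-style namespace access
    dv = rc[:]
    dv['a'] = 5                           # push a=5 to every engine
    dv.scatter('x', range(16))            # partition range(16) across the engines
    parts = dv.gather('x', block=True)    # reassemble the partitions locally

    # LoadBalancedView: work is routed through the task scheduler
    lv = rc.load_balanced_view()
    ar = lv.map(lambda n: n ** 2, range(16), block=False)   # AsyncMapResult
    squares = ar.get()                    # wait for, and collect, the results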