##// END OF EJS Templates
Backport PR #4106: fix a couple of default block values...
Thomas Kluyver -
Show More
@@ -1,1116 +1,1116 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 if self._in_sync_results:
60 60 return f(self, *args, **kwargs)
61 61 self._in_sync_results = True
62 62 try:
63 63 ret = f(self, *args, **kwargs)
64 64 finally:
65 65 self._in_sync_results = False
66 66 self._sync_results()
67 67 return ret
68 68
69 69 @decorator
70 70 def spin_after(f, self, *args, **kwargs):
71 71 """call spin after the method."""
72 72 ret = f(self, *args, **kwargs)
73 73 self.spin()
74 74 return ret
75 75
76 76 #-----------------------------------------------------------------------------
77 77 # Classes
78 78 #-----------------------------------------------------------------------------
79 79
80 80 @skip_doctest
81 81 class View(HasTraits):
82 82 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
83 83
84 84 Don't use this class, use subclasses.
85 85
86 86 Methods
87 87 -------
88 88
89 89 spin
90 90 flushes incoming results and registration state changes
91 91 control methods spin, and requesting `ids` also ensures up to date
92 92
93 93 wait
94 94 wait on one or more msg_ids
95 95
96 96 execution methods
97 97 apply
98 98 legacy: execute, run
99 99
100 100 data movement
101 101 push, pull, scatter, gather
102 102
103 103 query methods
104 104 get_result, queue_status, purge_results, result_status
105 105
106 106 control methods
107 107 abort, shutdown
108 108
109 109 """
110 110 # flags
111 111 block=Bool(False)
112 112 track=Bool(True)
113 113 targets = Any()
114 114
115 115 history=List()
116 116 outstanding = Set()
117 117 results = Dict()
118 118 client = Instance('IPython.parallel.Client')
119 119
120 120 _socket = Instance('zmq.Socket')
121 121 _flag_names = List(['targets', 'block', 'track'])
122 122 _in_sync_results = Bool(False)
123 123 _targets = Any()
124 124 _idents = Any()
125 125
126 126 def __init__(self, client=None, socket=None, **flags):
127 127 super(View, self).__init__(client=client, _socket=socket)
128 128 self.results = client.results
129 129 self.block = client.block
130 130
131 131 self.set_flags(**flags)
132 132
133 133 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
134 134
135 135 def __repr__(self):
136 136 strtargets = str(self.targets)
137 137 if len(strtargets) > 16:
138 138 strtargets = strtargets[:12]+'...]'
139 139 return "<%s %s>"%(self.__class__.__name__, strtargets)
140 140
141 141 def __len__(self):
142 142 if isinstance(self.targets, list):
143 143 return len(self.targets)
144 144 elif isinstance(self.targets, int):
145 145 return 1
146 146 else:
147 147 return len(self.client)
148 148
149 149 def set_flags(self, **kwargs):
150 150 """set my attribute flags by keyword.
151 151
152 152 Views determine behavior with a few attributes (`block`, `track`, etc.).
153 153 These attributes can be set all at once by name with this method.
154 154
155 155 Parameters
156 156 ----------
157 157
158 158 block : bool
159 159 whether to wait for results
160 160 track : bool
161 161 whether to create a MessageTracker to allow the user to
162 162 safely edit after arrays and buffers during non-copying
163 163 sends.
164 164 """
165 165 for name, value in kwargs.iteritems():
166 166 if name not in self._flag_names:
167 167 raise KeyError("Invalid name: %r"%name)
168 168 else:
169 169 setattr(self, name, value)
170 170
171 171 @contextmanager
172 172 def temp_flags(self, **kwargs):
173 173 """temporarily set flags, for use in `with` statements.
174 174
175 175 See set_flags for permanent setting of flags
176 176
177 177 Examples
178 178 --------
179 179
180 180 >>> view.track=False
181 181 ...
182 182 >>> with view.temp_flags(track=True):
183 183 ... ar = view.apply(dostuff, my_big_array)
184 184 ... ar.tracker.wait() # wait for send to finish
185 185 >>> view.track
186 186 False
187 187
188 188 """
189 189 # preflight: save flags, and set temporaries
190 190 saved_flags = {}
191 191 for f in self._flag_names:
192 192 saved_flags[f] = getattr(self, f)
193 193 self.set_flags(**kwargs)
194 194 # yield to the with-statement block
195 195 try:
196 196 yield
197 197 finally:
198 198 # postflight: restore saved flags
199 199 self.set_flags(**saved_flags)
200 200
201 201
202 202 #----------------------------------------------------------------
203 203 # apply
204 204 #----------------------------------------------------------------
205 205
206 206 def _sync_results(self):
207 207 """to be called by @sync_results decorator
208 208
209 209 after submitting any tasks.
210 210 """
211 211 delta = self.outstanding.difference(self.client.outstanding)
212 212 completed = self.outstanding.intersection(delta)
213 213 self.outstanding = self.outstanding.difference(completed)
214 214
215 215 @sync_results
216 216 @save_ids
217 217 def _really_apply(self, f, args, kwargs, block=None, **options):
218 218 """wrapper for client.send_apply_request"""
219 219 raise NotImplementedError("Implement in subclasses")
220 220
221 221 def apply(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines, returning the result.
223 223
224 224 This method sets all apply flags via this View's attributes.
225 225
226 226 if self.block is False:
227 227 returns AsyncResult
228 228 else:
229 229 returns actual result of f(*args, **kwargs)
230 230 """
231 231 return self._really_apply(f, args, kwargs)
232 232
233 233 def apply_async(self, f, *args, **kwargs):
234 234 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
235 235
236 236 returns AsyncResult
237 237 """
238 238 return self._really_apply(f, args, kwargs, block=False)
239 239
240 240 @spin_after
241 241 def apply_sync(self, f, *args, **kwargs):
242 242 """calls f(*args, **kwargs) on remote engines in a blocking manner,
243 243 returning the result.
244 244
245 245 returns: actual result of f(*args, **kwargs)
246 246 """
247 247 return self._really_apply(f, args, kwargs, block=True)
248 248
249 249 #----------------------------------------------------------------
250 250 # wrappers for client and control methods
251 251 #----------------------------------------------------------------
252 252 @sync_results
253 253 def spin(self):
254 254 """spin the client, and sync"""
255 255 self.client.spin()
256 256
257 257 @sync_results
258 258 def wait(self, jobs=None, timeout=-1):
259 259 """waits on one or more `jobs`, for up to `timeout` seconds.
260 260
261 261 Parameters
262 262 ----------
263 263
264 264 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
265 265 ints are indices to self.history
266 266 strs are msg_ids
267 267 default: wait on all outstanding messages
268 268 timeout : float
269 269 a time in seconds, after which to give up.
270 270 default is -1, which means no timeout
271 271
272 272 Returns
273 273 -------
274 274
275 275 True : when all msg_ids are done
276 276 False : timeout reached, some msg_ids still outstanding
277 277 """
278 278 if jobs is None:
279 279 jobs = self.history
280 280 return self.client.wait(jobs, timeout)
281 281
282 282 def abort(self, jobs=None, targets=None, block=None):
283 283 """Abort jobs on my engines.
284 284
285 285 Parameters
286 286 ----------
287 287
288 288 jobs : None, str, list of strs, optional
289 289 if None: abort all jobs.
290 290 else: abort specific msg_id(s).
291 291 """
292 292 block = block if block is not None else self.block
293 293 targets = targets if targets is not None else self.targets
294 294 jobs = jobs if jobs is not None else list(self.outstanding)
295 295
296 296 return self.client.abort(jobs=jobs, targets=targets, block=block)
297 297
298 298 def queue_status(self, targets=None, verbose=False):
299 299 """Fetch the Queue status of my engines"""
300 300 targets = targets if targets is not None else self.targets
301 301 return self.client.queue_status(targets=targets, verbose=verbose)
302 302
303 303 def purge_results(self, jobs=[], targets=[]):
304 304 """Instruct the controller to forget specific results."""
305 305 if targets is None or targets == 'all':
306 306 targets = self.targets
307 307 return self.client.purge_results(jobs=jobs, targets=targets)
308 308
309 309 def shutdown(self, targets=None, restart=False, hub=False, block=None):
310 310 """Terminates one or more engine processes, optionally including the hub.
311 311 """
312 312 block = self.block if block is None else block
313 313 if targets is None or targets == 'all':
314 314 targets = self.targets
315 315 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
316 316
317 317 @spin_after
318 318 def get_result(self, indices_or_msg_ids=None):
319 319 """return one or more results, specified by history index or msg_id.
320 320
321 321 See client.get_result for details.
322 322
323 323 """
324 324
325 325 if indices_or_msg_ids is None:
326 326 indices_or_msg_ids = -1
327 327 if isinstance(indices_or_msg_ids, int):
328 328 indices_or_msg_ids = self.history[indices_or_msg_ids]
329 329 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
330 330 indices_or_msg_ids = list(indices_or_msg_ids)
331 331 for i,index in enumerate(indices_or_msg_ids):
332 332 if isinstance(index, int):
333 333 indices_or_msg_ids[i] = self.history[index]
334 334 return self.client.get_result(indices_or_msg_ids)
335 335
336 336 #-------------------------------------------------------------------
337 337 # Map
338 338 #-------------------------------------------------------------------
339 339
340 340 @sync_results
341 341 def map(self, f, *sequences, **kwargs):
342 342 """override in subclasses"""
343 343 raise NotImplementedError
344 344
345 345 def map_async(self, f, *sequences, **kwargs):
346 346 """Parallel version of builtin `map`, using this view's engines.
347 347
348 348 This is equivalent to map(...block=False)
349 349
350 350 See `self.map` for details.
351 351 """
352 352 if 'block' in kwargs:
353 353 raise TypeError("map_async doesn't take a `block` keyword argument.")
354 354 kwargs['block'] = False
355 355 return self.map(f,*sequences,**kwargs)
356 356
357 357 def map_sync(self, f, *sequences, **kwargs):
358 358 """Parallel version of builtin `map`, using this view's engines.
359 359
360 360 This is equivalent to map(...block=True)
361 361
362 362 See `self.map` for details.
363 363 """
364 364 if 'block' in kwargs:
365 365 raise TypeError("map_sync doesn't take a `block` keyword argument.")
366 366 kwargs['block'] = True
367 367 return self.map(f,*sequences,**kwargs)
368 368
369 369 def imap(self, f, *sequences, **kwargs):
370 370 """Parallel version of `itertools.imap`.
371 371
372 372 See `self.map` for details.
373 373
374 374 """
375 375
376 376 return iter(self.map_async(f,*sequences, **kwargs))
377 377
378 378 #-------------------------------------------------------------------
379 379 # Decorators
380 380 #-------------------------------------------------------------------
381 381
382 def remote(self, block=True, **flags):
382 def remote(self, block=None, **flags):
383 383 """Decorator for making a RemoteFunction"""
384 384 block = self.block if block is None else block
385 385 return remote(self, block=block, **flags)
386 386
387 387 def parallel(self, dist='b', block=None, **flags):
388 388 """Decorator for making a ParallelFunction"""
389 389 block = self.block if block is None else block
390 390 return parallel(self, dist=dist, block=block, **flags)
391 391
392 392 @skip_doctest
393 393 class DirectView(View):
394 394 """Direct Multiplexer View of one or more engines.
395 395
396 396 These are created via indexed access to a client:
397 397
398 398 >>> dv_1 = client[1]
399 399 >>> dv_all = client[:]
400 400 >>> dv_even = client[::2]
401 401 >>> dv_some = client[1:3]
402 402
403 403 This object provides dictionary access to engine namespaces:
404 404
405 405 # push a=5:
406 406 >>> dv['a'] = 5
407 407 # pull 'foo':
408 408 >>> db['foo']
409 409
410 410 """
411 411
412 412 def __init__(self, client=None, socket=None, targets=None):
413 413 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
414 414
415 415 @property
416 416 def importer(self):
417 417 """sync_imports(local=True) as a property.
418 418
419 419 See sync_imports for details.
420 420
421 421 """
422 422 return self.sync_imports(True)
423 423
424 424 @contextmanager
425 425 def sync_imports(self, local=True, quiet=False):
426 426 """Context Manager for performing simultaneous local and remote imports.
427 427
428 428 'import x as y' will *not* work. The 'as y' part will simply be ignored.
429 429
430 430 If `local=True`, then the package will also be imported locally.
431 431
432 432 If `quiet=True`, no output will be produced when attempting remote
433 433 imports.
434 434
435 435 Note that remote-only (`local=False`) imports have not been implemented.
436 436
437 437 >>> with view.sync_imports():
438 438 ... from numpy import recarray
439 439 importing recarray from numpy on engine(s)
440 440
441 441 """
442 442 import __builtin__
443 443 local_import = __builtin__.__import__
444 444 modules = set()
445 445 results = []
446 446 @util.interactive
447 447 def remote_import(name, fromlist, level):
448 448 """the function to be passed to apply, that actually performs the import
449 449 on the engine, and loads up the user namespace.
450 450 """
451 451 import sys
452 452 user_ns = globals()
453 453 mod = __import__(name, fromlist=fromlist, level=level)
454 454 if fromlist:
455 455 for key in fromlist:
456 456 user_ns[key] = getattr(mod, key)
457 457 else:
458 458 user_ns[name] = sys.modules[name]
459 459
460 460 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
461 461 """the drop-in replacement for __import__, that optionally imports
462 462 locally as well.
463 463 """
464 464 # don't override nested imports
465 465 save_import = __builtin__.__import__
466 466 __builtin__.__import__ = local_import
467 467
468 468 if imp.lock_held():
469 469 # this is a side-effect import, don't do it remotely, or even
470 470 # ignore the local effects
471 471 return local_import(name, globals, locals, fromlist, level)
472 472
473 473 imp.acquire_lock()
474 474 if local:
475 475 mod = local_import(name, globals, locals, fromlist, level)
476 476 else:
477 477 raise NotImplementedError("remote-only imports not yet implemented")
478 478 imp.release_lock()
479 479
480 480 key = name+':'+','.join(fromlist or [])
481 481 if level == -1 and key not in modules:
482 482 modules.add(key)
483 483 if not quiet:
484 484 if fromlist:
485 485 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
486 486 else:
487 487 print "importing %s on engine(s)"%name
488 488 results.append(self.apply_async(remote_import, name, fromlist, level))
489 489 # restore override
490 490 __builtin__.__import__ = save_import
491 491
492 492 return mod
493 493
494 494 # override __import__
495 495 __builtin__.__import__ = view_import
496 496 try:
497 497 # enter the block
498 498 yield
499 499 except ImportError:
500 500 if local:
501 501 raise
502 502 else:
503 503 # ignore import errors if not doing local imports
504 504 pass
505 505 finally:
506 506 # always restore __import__
507 507 __builtin__.__import__ = local_import
508 508
509 509 for r in results:
510 510 # raise possible remote ImportErrors here
511 511 r.get()
512 512
513 513
514 514 @sync_results
515 515 @save_ids
516 516 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
517 517 """calls f(*args, **kwargs) on remote engines, returning the result.
518 518
519 519 This method sets all of `apply`'s flags via this View's attributes.
520 520
521 521 Parameters
522 522 ----------
523 523
524 524 f : callable
525 525
526 526 args : list [default: empty]
527 527
528 528 kwargs : dict [default: empty]
529 529
530 530 targets : target list [default: self.targets]
531 531 where to run
532 532 block : bool [default: self.block]
533 533 whether to block
534 534 track : bool [default: self.track]
535 535 whether to ask zmq to track the message, for safe non-copying sends
536 536
537 537 Returns
538 538 -------
539 539
540 540 if self.block is False:
541 541 returns AsyncResult
542 542 else:
543 543 returns actual result of f(*args, **kwargs) on the engine(s)
544 544 This will be a list of self.targets is also a list (even length 1), or
545 545 the single result if self.targets is an integer engine id
546 546 """
547 547 args = [] if args is None else args
548 548 kwargs = {} if kwargs is None else kwargs
549 549 block = self.block if block is None else block
550 550 track = self.track if track is None else track
551 551 targets = self.targets if targets is None else targets
552 552
553 553 _idents, _targets = self.client._build_targets(targets)
554 554 msg_ids = []
555 555 trackers = []
556 556 for ident in _idents:
557 557 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
558 558 ident=ident)
559 559 if track:
560 560 trackers.append(msg['tracker'])
561 561 msg_ids.append(msg['header']['msg_id'])
562 562 if isinstance(targets, int):
563 563 msg_ids = msg_ids[0]
564 564 tracker = None if track is False else zmq.MessageTracker(*trackers)
565 565 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
566 566 if block:
567 567 try:
568 568 return ar.get()
569 569 except KeyboardInterrupt:
570 570 pass
571 571 return ar
572 572
573 573
574 574 @sync_results
575 575 def map(self, f, *sequences, **kwargs):
576 576 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
577 577
578 578 Parallel version of builtin `map`, using this View's `targets`.
579 579
580 580 There will be one task per target, so work will be chunked
581 581 if the sequences are longer than `targets`.
582 582
583 583 Results can be iterated as they are ready, but will become available in chunks.
584 584
585 585 Parameters
586 586 ----------
587 587
588 588 f : callable
589 589 function to be mapped
590 590 *sequences: one or more sequences of matching length
591 591 the sequences to be distributed and passed to `f`
592 592 block : bool
593 593 whether to wait for the result or not [default self.block]
594 594
595 595 Returns
596 596 -------
597 597
598 598 if block=False:
599 599 AsyncMapResult
600 600 An object like AsyncResult, but which reassembles the sequence of results
601 601 into a single list. AsyncMapResults can be iterated through before all
602 602 results are complete.
603 603 else:
604 604 list
605 605 the result of map(f,*sequences)
606 606 """
607 607
608 608 block = kwargs.pop('block', self.block)
609 609 for k in kwargs.keys():
610 610 if k not in ['block', 'track']:
611 611 raise TypeError("invalid keyword arg, %r"%k)
612 612
613 613 assert len(sequences) > 0, "must have some sequences to map onto!"
614 614 pf = ParallelFunction(self, f, block=block, **kwargs)
615 615 return pf.map(*sequences)
616 616
617 617 @sync_results
618 618 @save_ids
619 619 def execute(self, code, silent=True, targets=None, block=None):
620 620 """Executes `code` on `targets` in blocking or nonblocking manner.
621 621
622 622 ``execute`` is always `bound` (affects engine namespace)
623 623
624 624 Parameters
625 625 ----------
626 626
627 627 code : str
628 628 the code string to be executed
629 629 block : bool
630 630 whether or not to wait until done to return
631 631 default: self.block
632 632 """
633 633 block = self.block if block is None else block
634 634 targets = self.targets if targets is None else targets
635 635
636 636 _idents, _targets = self.client._build_targets(targets)
637 637 msg_ids = []
638 638 trackers = []
639 639 for ident in _idents:
640 640 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
641 641 msg_ids.append(msg['header']['msg_id'])
642 642 if isinstance(targets, int):
643 643 msg_ids = msg_ids[0]
644 644 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
645 645 if block:
646 646 try:
647 647 ar.get()
648 648 except KeyboardInterrupt:
649 649 pass
650 650 return ar
651 651
652 652 def run(self, filename, targets=None, block=None):
653 653 """Execute contents of `filename` on my engine(s).
654 654
655 655 This simply reads the contents of the file and calls `execute`.
656 656
657 657 Parameters
658 658 ----------
659 659
660 660 filename : str
661 661 The path to the file
662 662 targets : int/str/list of ints/strs
663 663 the engines on which to execute
664 664 default : all
665 665 block : bool
666 666 whether or not to wait until done
667 667 default: self.block
668 668
669 669 """
670 670 with open(filename, 'r') as f:
671 671 # add newline in case of trailing indented whitespace
672 672 # which will cause SyntaxError
673 673 code = f.read()+'\n'
674 674 return self.execute(code, block=block, targets=targets)
675 675
676 676 def update(self, ns):
677 677 """update remote namespace with dict `ns`
678 678
679 679 See `push` for details.
680 680 """
681 681 return self.push(ns, block=self.block, track=self.track)
682 682
683 683 def push(self, ns, targets=None, block=None, track=None):
684 684 """update remote namespace with dict `ns`
685 685
686 686 Parameters
687 687 ----------
688 688
689 689 ns : dict
690 690 dict of keys with which to update engine namespace(s)
691 691 block : bool [default : self.block]
692 692 whether to wait to be notified of engine receipt
693 693
694 694 """
695 695
696 696 block = block if block is not None else self.block
697 697 track = track if track is not None else self.track
698 698 targets = targets if targets is not None else self.targets
699 699 # applier = self.apply_sync if block else self.apply_async
700 700 if not isinstance(ns, dict):
701 701 raise TypeError("Must be a dict, not %s"%type(ns))
702 702 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
703 703
704 704 def get(self, key_s):
705 705 """get object(s) by `key_s` from remote namespace
706 706
707 707 see `pull` for details.
708 708 """
709 709 # block = block if block is not None else self.block
710 710 return self.pull(key_s, block=True)
711 711
712 712 def pull(self, names, targets=None, block=None):
713 713 """get object(s) by `name` from remote namespace
714 714
715 715 will return one object if it is a key.
716 716 can also take a list of keys, in which case it will return a list of objects.
717 717 """
718 718 block = block if block is not None else self.block
719 719 targets = targets if targets is not None else self.targets
720 720 applier = self.apply_sync if block else self.apply_async
721 721 if isinstance(names, basestring):
722 722 pass
723 723 elif isinstance(names, (list,tuple,set)):
724 724 for key in names:
725 725 if not isinstance(key, basestring):
726 726 raise TypeError("keys must be str, not type %r"%type(key))
727 727 else:
728 728 raise TypeError("names must be strs, not %r"%names)
729 729 return self._really_apply(util._pull, (names,), block=block, targets=targets)
730 730
731 731 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
732 732 """
733 733 Partition a Python sequence and send the partitions to a set of engines.
734 734 """
735 735 block = block if block is not None else self.block
736 736 track = track if track is not None else self.track
737 737 targets = targets if targets is not None else self.targets
738 738
739 739 # construct integer ID list:
740 740 targets = self.client._build_targets(targets)[1]
741 741
742 742 mapObject = Map.dists[dist]()
743 743 nparts = len(targets)
744 744 msg_ids = []
745 745 trackers = []
746 746 for index, engineid in enumerate(targets):
747 747 partition = mapObject.getPartition(seq, index, nparts)
748 748 if flatten and len(partition) == 1:
749 749 ns = {key: partition[0]}
750 750 else:
751 751 ns = {key: partition}
752 752 r = self.push(ns, block=False, track=track, targets=engineid)
753 753 msg_ids.extend(r.msg_ids)
754 754 if track:
755 755 trackers.append(r._tracker)
756 756
757 757 if track:
758 758 tracker = zmq.MessageTracker(*trackers)
759 759 else:
760 760 tracker = None
761 761
762 762 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
763 763 if block:
764 764 r.wait()
765 765 else:
766 766 return r
767 767
768 768 @sync_results
769 769 @save_ids
770 770 def gather(self, key, dist='b', targets=None, block=None):
771 771 """
772 772 Gather a partitioned sequence on a set of engines as a single local seq.
773 773 """
774 774 block = block if block is not None else self.block
775 775 targets = targets if targets is not None else self.targets
776 776 mapObject = Map.dists[dist]()
777 777 msg_ids = []
778 778
779 779 # construct integer ID list:
780 780 targets = self.client._build_targets(targets)[1]
781 781
782 782 for index, engineid in enumerate(targets):
783 783 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
784 784
785 785 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
786 786
787 787 if block:
788 788 try:
789 789 return r.get()
790 790 except KeyboardInterrupt:
791 791 pass
792 792 return r
793 793
794 794 def __getitem__(self, key):
795 795 return self.get(key)
796 796
797 797 def __setitem__(self,key, value):
798 798 self.update({key:value})
799 799
800 def clear(self, targets=None, block=False):
800 def clear(self, targets=None, block=None):
801 801 """Clear the remote namespaces on my engines."""
802 802 block = block if block is not None else self.block
803 803 targets = targets if targets is not None else self.targets
804 804 return self.client.clear(targets=targets, block=block)
805 805
806 806 #----------------------------------------
807 807 # activate for %px, %autopx, etc. magics
808 808 #----------------------------------------
809 809
810 810 def activate(self, suffix=''):
811 811 """Activate IPython magics associated with this View
812 812
813 813 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
814 814
815 815 Parameters
816 816 ----------
817 817
818 818 suffix: str [default: '']
819 819 The suffix, if any, for the magics. This allows you to have
820 820 multiple views associated with parallel magics at the same time.
821 821
822 822 e.g. ``rc[::2].activate(suffix='_even')`` will give you
823 823 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
824 824 on the even engines.
825 825 """
826 826
827 827 from IPython.parallel.client.magics import ParallelMagics
828 828
829 829 try:
830 830 # This is injected into __builtins__.
831 831 ip = get_ipython()
832 832 except NameError:
833 833 print "The IPython parallel magics (%px, etc.) only work within IPython."
834 834 return
835 835
836 836 M = ParallelMagics(ip, self, suffix)
837 837 ip.magics_manager.register(M)
838 838
839 839
840 840 @skip_doctest
841 841 class LoadBalancedView(View):
842 842 """An load-balancing View that only executes via the Task scheduler.
843 843
844 844 Load-balanced views can be created with the client's `view` method:
845 845
846 846 >>> v = client.load_balanced_view()
847 847
848 848 or targets can be specified, to restrict the potential destinations:
849 849
850 850 >>> v = client.client.load_balanced_view([1,3])
851 851
852 852 which would restrict loadbalancing to between engines 1 and 3.
853 853
854 854 """
855 855
856 856 follow=Any()
857 857 after=Any()
858 858 timeout=CFloat()
859 859 retries = Integer(0)
860 860
861 861 _task_scheme = Any()
862 862 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
863 863
864 864 def __init__(self, client=None, socket=None, **flags):
865 865 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
866 866 self._task_scheme=client._task_scheme
867 867
868 868 def _validate_dependency(self, dep):
869 869 """validate a dependency.
870 870
871 871 For use in `set_flags`.
872 872 """
873 873 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
874 874 return True
875 875 elif isinstance(dep, (list,set, tuple)):
876 876 for d in dep:
877 877 if not isinstance(d, (basestring, AsyncResult)):
878 878 return False
879 879 elif isinstance(dep, dict):
880 880 if set(dep.keys()) != set(Dependency().as_dict().keys()):
881 881 return False
882 882 if not isinstance(dep['msg_ids'], list):
883 883 return False
884 884 for d in dep['msg_ids']:
885 885 if not isinstance(d, basestring):
886 886 return False
887 887 else:
888 888 return False
889 889
890 890 return True
891 891
892 892 def _render_dependency(self, dep):
893 893 """helper for building jsonable dependencies from various input forms."""
894 894 if isinstance(dep, Dependency):
895 895 return dep.as_dict()
896 896 elif isinstance(dep, AsyncResult):
897 897 return dep.msg_ids
898 898 elif dep is None:
899 899 return []
900 900 else:
901 901 # pass to Dependency constructor
902 902 return list(Dependency(dep))
903 903
904 904 def set_flags(self, **kwargs):
905 905 """set my attribute flags by keyword.
906 906
907 907 A View is a wrapper for the Client's apply method, but with attributes
908 908 that specify keyword arguments, those attributes can be set by keyword
909 909 argument with this method.
910 910
911 911 Parameters
912 912 ----------
913 913
914 914 block : bool
915 915 whether to wait for results
916 916 track : bool
917 917 whether to create a MessageTracker to allow the user to
918 918 safely edit after arrays and buffers during non-copying
919 919 sends.
920 920
921 921 after : Dependency or collection of msg_ids
922 922 Only for load-balanced execution (targets=None)
923 923 Specify a list of msg_ids as a time-based dependency.
924 924 This job will only be run *after* the dependencies
925 925 have been met.
926 926
927 927 follow : Dependency or collection of msg_ids
928 928 Only for load-balanced execution (targets=None)
929 929 Specify a list of msg_ids as a location-based dependency.
930 930 This job will only be run on an engine where this dependency
931 931 is met.
932 932
933 933 timeout : float/int or None
934 934 Only for load-balanced execution (targets=None)
935 935 Specify an amount of time (in seconds) for the scheduler to
936 936 wait for dependencies to be met before failing with a
937 937 DependencyTimeout.
938 938
939 939 retries : int
940 940 Number of times a task will be retried on failure.
941 941 """
942 942
943 943 super(LoadBalancedView, self).set_flags(**kwargs)
944 944 for name in ('follow', 'after'):
945 945 if name in kwargs:
946 946 value = kwargs[name]
947 947 if self._validate_dependency(value):
948 948 setattr(self, name, value)
949 949 else:
950 950 raise ValueError("Invalid dependency: %r"%value)
951 951 if 'timeout' in kwargs:
952 952 t = kwargs['timeout']
953 953 if not isinstance(t, (int, long, float, type(None))):
954 954 raise TypeError("Invalid type for timeout: %r"%type(t))
955 955 if t is not None:
956 956 if t < 0:
957 957 raise ValueError("Invalid timeout: %s"%t)
958 958 self.timeout = t
959 959
960 960 @sync_results
961 961 @save_ids
962 962 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
963 963 after=None, follow=None, timeout=None,
964 964 targets=None, retries=None):
965 965 """calls f(*args, **kwargs) on a remote engine, returning the result.
966 966
967 967 This method temporarily sets all of `apply`'s flags for a single call.
968 968
969 969 Parameters
970 970 ----------
971 971
972 972 f : callable
973 973
974 974 args : list [default: empty]
975 975
976 976 kwargs : dict [default: empty]
977 977
978 978 block : bool [default: self.block]
979 979 whether to block
980 980 track : bool [default: self.track]
981 981 whether to ask zmq to track the message, for safe non-copying sends
982 982
983 983 !!!!!! TODO: THE REST HERE !!!!
984 984
985 985 Returns
986 986 -------
987 987
988 988 if self.block is False:
989 989 returns AsyncResult
990 990 else:
991 991 returns actual result of f(*args, **kwargs) on the engine(s)
992 992 This will be a list of self.targets is also a list (even length 1), or
993 993 the single result if self.targets is an integer engine id
994 994 """
995 995
996 996 # validate whether we can run
997 997 if self._socket.closed:
998 998 msg = "Task farming is disabled"
999 999 if self._task_scheme == 'pure':
1000 1000 msg += " because the pure ZMQ scheduler cannot handle"
1001 1001 msg += " disappearing engines."
1002 1002 raise RuntimeError(msg)
1003 1003
1004 1004 if self._task_scheme == 'pure':
1005 1005 # pure zmq scheme doesn't support extra features
1006 1006 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1007 1007 "follow, after, retries, targets, timeout"
1008 1008 if (follow or after or retries or targets or timeout):
1009 1009 # hard fail on Scheduler flags
1010 1010 raise RuntimeError(msg)
1011 1011 if isinstance(f, dependent):
1012 1012 # soft warn on functional dependencies
1013 1013 warnings.warn(msg, RuntimeWarning)
1014 1014
1015 1015 # build args
1016 1016 args = [] if args is None else args
1017 1017 kwargs = {} if kwargs is None else kwargs
1018 1018 block = self.block if block is None else block
1019 1019 track = self.track if track is None else track
1020 1020 after = self.after if after is None else after
1021 1021 retries = self.retries if retries is None else retries
1022 1022 follow = self.follow if follow is None else follow
1023 1023 timeout = self.timeout if timeout is None else timeout
1024 1024 targets = self.targets if targets is None else targets
1025 1025
1026 1026 if not isinstance(retries, int):
1027 1027 raise TypeError('retries must be int, not %r'%type(retries))
1028 1028
1029 1029 if targets is None:
1030 1030 idents = []
1031 1031 else:
1032 1032 idents = self.client._build_targets(targets)[0]
1033 1033 # ensure *not* bytes
1034 1034 idents = [ ident.decode() for ident in idents ]
1035 1035
1036 1036 after = self._render_dependency(after)
1037 1037 follow = self._render_dependency(follow)
1038 1038 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1039 1039
1040 1040 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1041 1041 metadata=metadata)
1042 1042 tracker = None if track is False else msg['tracker']
1043 1043
1044 1044 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1045 1045
1046 1046 if block:
1047 1047 try:
1048 1048 return ar.get()
1049 1049 except KeyboardInterrupt:
1050 1050 pass
1051 1051 return ar
1052 1052
1053 1053 @sync_results
1054 1054 @save_ids
1055 1055 def map(self, f, *sequences, **kwargs):
1056 1056 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1057 1057
1058 1058 Parallel version of builtin `map`, load-balanced by this View.
1059 1059
1060 1060 `block`, and `chunksize` can be specified by keyword only.
1061 1061
1062 1062 Each `chunksize` elements will be a separate task, and will be
1063 1063 load-balanced. This lets individual elements be available for iteration
1064 1064 as soon as they arrive.
1065 1065
1066 1066 Parameters
1067 1067 ----------
1068 1068
1069 1069 f : callable
1070 1070 function to be mapped
1071 1071 *sequences: one or more sequences of matching length
1072 1072 the sequences to be distributed and passed to `f`
1073 1073 block : bool [default self.block]
1074 1074 whether to wait for the result or not
1075 1075 track : bool
1076 1076 whether to create a MessageTracker to allow the user to
1077 1077 safely edit after arrays and buffers during non-copying
1078 1078 sends.
1079 1079 chunksize : int [default 1]
1080 1080 how many elements should be in each task.
1081 1081 ordered : bool [default True]
1082 1082 Whether the results should be gathered as they arrive, or enforce
1083 1083 the order of submission.
1084 1084
1085 1085 Only applies when iterating through AsyncMapResult as results arrive.
1086 1086 Has no effect when block=True.
1087 1087
1088 1088 Returns
1089 1089 -------
1090 1090
1091 1091 if block=False:
1092 1092 AsyncMapResult
1093 1093 An object like AsyncResult, but which reassembles the sequence of results
1094 1094 into a single list. AsyncMapResults can be iterated through before all
1095 1095 results are complete.
1096 1096 else:
1097 1097 the result of map(f,*sequences)
1098 1098
1099 1099 """
1100 1100
1101 1101 # default
1102 1102 block = kwargs.get('block', self.block)
1103 1103 chunksize = kwargs.get('chunksize', 1)
1104 1104 ordered = kwargs.get('ordered', True)
1105 1105
1106 1106 keyset = set(kwargs.keys())
1107 1107 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1108 1108 if extra_keys:
1109 1109 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1110 1110
1111 1111 assert len(sequences) > 0, "must have some sequences to map onto!"
1112 1112
1113 1113 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1114 1114 return pf.map(*sequences)
1115 1115
1116 1116 __all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now