Adds a quiet keyword to sync_imports to allow users to suppress messages about imports on remote engines.
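For example (a minimal sketch; assumes an IPython.parallel Client connected to a running cluster, e.g. one started with `ipcluster start` — the import target is purely illustrative):

    >>> from IPython.parallel import Client
    >>> rc = Client()
    >>> dv = rc[:]
    >>> with dv.sync_imports():            # default behaviour: one message per import
    ...     import numpy
    importing numpy on engine(s)
    >>> with dv.sync_imports(quiet=True):  # new keyword: the import happens silently
    ...     import numpy

Note that the `importer` property is unchanged and still calls `sync_imports(True)` with the default `quiet=False`, so `with view.importer:` continues to print these messages; call `sync_imports(quiet=True)` directly to silence them.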
Ben Edwards -
@@ -1,1065 +1,1069 @@
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89 control methods spin, and requesting `ids` also ensures the view is up to date
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130 130
131 131
132 132 def __repr__(self):
133 133 strtargets = str(self.targets)
134 134 if len(strtargets) > 16:
135 135 strtargets = strtargets[:12]+'...]'
136 136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137 137
138 138 def set_flags(self, **kwargs):
139 139 """set my attribute flags by keyword.
140 140
141 141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 142 These attributes can be set all at once by name with this method.
143 143
144 144 Parameters
145 145 ----------
146 146
147 147 block : bool
148 148 whether to wait for results
149 149 track : bool
150 150 whether to create a MessageTracker to allow the user to
151 151 safely edit arrays and buffers after non-copying
152 152 sends.
153 153 """
154 154 for name, value in kwargs.iteritems():
155 155 if name not in self._flag_names:
156 156 raise KeyError("Invalid name: %r"%name)
157 157 else:
158 158 setattr(self, name, value)
159 159
160 160 @contextmanager
161 161 def temp_flags(self, **kwargs):
162 162 """temporarily set flags, for use in `with` statements.
163 163
164 164 See set_flags for permanent setting of flags
165 165
166 166 Examples
167 167 --------
168 168
169 169 >>> view.track=False
170 170 ...
171 171 >>> with view.temp_flags(track=True):
172 172 ... ar = view.apply(dostuff, my_big_array)
173 173 ... ar.tracker.wait() # wait for send to finish
174 174 >>> view.track
175 175 False
176 176
177 177 """
178 178 # preflight: save flags, and set temporaries
179 179 saved_flags = {}
180 180 for f in self._flag_names:
181 181 saved_flags[f] = getattr(self, f)
182 182 self.set_flags(**kwargs)
183 183 # yield to the with-statement block
184 184 try:
185 185 yield
186 186 finally:
187 187 # postflight: restore saved flags
188 188 self.set_flags(**saved_flags)
189 189
190 190
191 191 #----------------------------------------------------------------
192 192 # apply
193 193 #----------------------------------------------------------------
194 194
195 195 @sync_results
196 196 @save_ids
197 197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 198 """wrapper for client.send_apply_message"""
199 199 raise NotImplementedError("Implement in subclasses")
200 200
201 201 def apply(self, f, *args, **kwargs):
202 202 """calls f(*args, **kwargs) on remote engines, returning the result.
203 203
204 204 This method sets all apply flags via this View's attributes.
205 205
206 206 if self.block is False:
207 207 returns AsyncResult
208 208 else:
209 209 returns actual result of f(*args, **kwargs)
210 210 """
211 211 return self._really_apply(f, args, kwargs)
212 212
213 213 def apply_async(self, f, *args, **kwargs):
214 214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215 215
216 216 returns AsyncResult
217 217 """
218 218 return self._really_apply(f, args, kwargs, block=False)
219 219
220 220 @spin_after
221 221 def apply_sync(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 223 returning the result.
224 224
225 225 returns: actual result of f(*args, **kwargs)
226 226 """
227 227 return self._really_apply(f, args, kwargs, block=True)
228 228
229 229 #----------------------------------------------------------------
230 230 # wrappers for client and control methods
231 231 #----------------------------------------------------------------
232 232 @sync_results
233 233 def spin(self):
234 234 """spin the client, and sync"""
235 235 self.client.spin()
236 236
237 237 @sync_results
238 238 def wait(self, jobs=None, timeout=-1):
239 239 """waits on one or more `jobs`, for up to `timeout` seconds.
240 240
241 241 Parameters
242 242 ----------
243 243
244 244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 245 ints are indices to self.history
246 246 strs are msg_ids
247 247 default: wait on all outstanding messages
248 248 timeout : float
249 249 a time in seconds, after which to give up.
250 250 default is -1, which means no timeout
251 251
252 252 Returns
253 253 -------
254 254
255 255 True : when all msg_ids are done
256 256 False : timeout reached, some msg_ids still outstanding
257 257 """
258 258 if jobs is None:
259 259 jobs = self.history
260 260 return self.client.wait(jobs, timeout)
261 261
262 262 def abort(self, jobs=None, targets=None, block=None):
263 263 """Abort jobs on my engines.
264 264
265 265 Parameters
266 266 ----------
267 267
268 268 jobs : None, str, list of strs, optional
269 269 if None: abort all jobs.
270 270 else: abort specific msg_id(s).
271 271 """
272 272 block = block if block is not None else self.block
273 273 targets = targets if targets is not None else self.targets
274 274 jobs = jobs if jobs is not None else list(self.outstanding)
275 275
276 276 return self.client.abort(jobs=jobs, targets=targets, block=block)
277 277
278 278 def queue_status(self, targets=None, verbose=False):
279 279 """Fetch the Queue status of my engines"""
280 280 targets = targets if targets is not None else self.targets
281 281 return self.client.queue_status(targets=targets, verbose=verbose)
282 282
283 283 def purge_results(self, jobs=[], targets=[]):
284 284 """Instruct the controller to forget specific results."""
285 285 if targets is None or targets == 'all':
286 286 targets = self.targets
287 287 return self.client.purge_results(jobs=jobs, targets=targets)
288 288
289 289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 290 """Terminates one or more engine processes, optionally including the hub.
291 291 """
292 292 block = self.block if block is None else block
293 293 if targets is None or targets == 'all':
294 294 targets = self.targets
295 295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296 296
297 297 @spin_after
298 298 def get_result(self, indices_or_msg_ids=None):
299 299 """return one or more results, specified by history index or msg_id.
300 300
301 301 See client.get_result for details.
302 302
303 303 """
304 304
305 305 if indices_or_msg_ids is None:
306 306 indices_or_msg_ids = -1
307 307 if isinstance(indices_or_msg_ids, int):
308 308 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 310 indices_or_msg_ids = list(indices_or_msg_ids)
311 311 for i,index in enumerate(indices_or_msg_ids):
312 312 if isinstance(index, int):
313 313 indices_or_msg_ids[i] = self.history[index]
314 314 return self.client.get_result(indices_or_msg_ids)
315 315
316 316 #-------------------------------------------------------------------
317 317 # Map
318 318 #-------------------------------------------------------------------
319 319
320 320 def map(self, f, *sequences, **kwargs):
321 321 """override in subclasses"""
322 322 raise NotImplementedError
323 323
324 324 def map_async(self, f, *sequences, **kwargs):
325 325 """Parallel version of builtin `map`, using this view's engines.
326 326
327 327 This is equivalent to map(...block=False)
328 328
329 329 See `self.map` for details.
330 330 """
331 331 if 'block' in kwargs:
332 332 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 333 kwargs['block'] = False
334 334 return self.map(f,*sequences,**kwargs)
335 335
336 336 def map_sync(self, f, *sequences, **kwargs):
337 337 """Parallel version of builtin `map`, using this view's engines.
338 338
339 339 This is equivalent to map(...block=True)
340 340
341 341 See `self.map` for details.
342 342 """
343 343 if 'block' in kwargs:
344 344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 345 kwargs['block'] = True
346 346 return self.map(f,*sequences,**kwargs)
347 347
348 348 def imap(self, f, *sequences, **kwargs):
349 349 """Parallel version of `itertools.imap`.
350 350
351 351 See `self.map` for details.
352 352
353 353 """
354 354
355 355 return iter(self.map_async(f,*sequences, **kwargs))
356 356
357 357 #-------------------------------------------------------------------
358 358 # Decorators
359 359 #-------------------------------------------------------------------
360 360
361 361 def remote(self, block=True, **flags):
362 362 """Decorator for making a RemoteFunction"""
363 363 block = self.block if block is None else block
364 364 return remote(self, block=block, **flags)
365 365
366 366 def parallel(self, dist='b', block=None, **flags):
367 367 """Decorator for making a ParallelFunction"""
368 368 block = self.block if block is None else block
369 369 return parallel(self, dist=dist, block=block, **flags)
370 370
371 371 @skip_doctest
372 372 class DirectView(View):
373 373 """Direct Multiplexer View of one or more engines.
374 374
375 375 These are created via indexed access to a client:
376 376
377 377 >>> dv_1 = client[1]
378 378 >>> dv_all = client[:]
379 379 >>> dv_even = client[::2]
380 380 >>> dv_some = client[1:3]
381 381
382 382 This object provides dictionary access to engine namespaces:
383 383
384 384 # push a=5:
385 385 >>> dv['a'] = 5
386 386 # pull 'foo':
387 387 >>> dv['foo']
388 388
389 389 """
390 390
391 391 def __init__(self, client=None, socket=None, targets=None):
392 392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393 393
394 394 @property
395 395 def importer(self):
396 396 """sync_imports(local=True) as a property.
397 397
398 398 See sync_imports for details.
399 399
400 400 """
401 401 return self.sync_imports(True)
402 402
403 403 @contextmanager
404 def sync_imports(self, local=True):
404 def sync_imports(self, local=True, quiet=False):
405 405 """Context Manager for performing simultaneous local and remote imports.
406 406
407 407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408 408
409 409 If `local=True`, then the package will also be imported locally.
410 410
411 If `quiet=True`, then no messages about the imports on the engines will
412 be printed.
413
411 414 Note that remote-only (`local=False`) imports have not been implemented.
412 415
413 416 >>> with view.sync_imports():
414 417 ... from numpy import recarray
415 418 importing recarray from numpy on engine(s)
416 419
417 420 """
418 421 import __builtin__
419 422 local_import = __builtin__.__import__
420 423 modules = set()
421 424 results = []
422 425 @util.interactive
423 426 def remote_import(name, fromlist, level):
424 427 """the function to be passed to apply, that actually performs the import
425 428 on the engine, and loads up the user namespace.
426 429 """
427 430 import sys
428 431 user_ns = globals()
429 432 mod = __import__(name, fromlist=fromlist, level=level)
430 433 if fromlist:
431 434 for key in fromlist:
432 435 user_ns[key] = getattr(mod, key)
433 436 else:
434 437 user_ns[name] = sys.modules[name]
435 438
436 439 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
437 440 """the drop-in replacement for __import__, that optionally imports
438 441 locally as well.
439 442 """
440 443 # don't override nested imports
441 444 save_import = __builtin__.__import__
442 445 __builtin__.__import__ = local_import
443 446
444 447 if imp.lock_held():
445 448 # this is a side-effect import, don't do it remotely, or even
446 449 # ignore the local effects
447 450 return local_import(name, globals, locals, fromlist, level)
448 451
449 452 imp.acquire_lock()
450 453 if local:
451 454 mod = local_import(name, globals, locals, fromlist, level)
452 455 else:
453 456 raise NotImplementedError("remote-only imports not yet implemented")
454 457 imp.release_lock()
455 458
456 459 key = name+':'+','.join(fromlist or [])
457 460 if level == -1 and key not in modules:
458 461 modules.add(key)
459 if fromlist:
460 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
461 else:
462 print "importing %s on engine(s)"%name
462 if not quiet:
463 if fromlist:
464 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
465 else:
466 print "importing %s on engine(s)"%name
463 467 results.append(self.apply_async(remote_import, name, fromlist, level))
464 468 # restore override
465 469 __builtin__.__import__ = save_import
466 470
467 471 return mod
468 472
469 473 # override __import__
470 474 __builtin__.__import__ = view_import
471 475 try:
472 476 # enter the block
473 477 yield
474 478 except ImportError:
475 479 if local:
476 480 raise
477 481 else:
478 482 # ignore import errors if not doing local imports
479 483 pass
480 484 finally:
481 485 # always restore __import__
482 486 __builtin__.__import__ = local_import
483 487
484 488 for r in results:
485 489 # raise possible remote ImportErrors here
486 490 r.get()
487 491
488 492
489 493 @sync_results
490 494 @save_ids
491 495 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
492 496 """calls f(*args, **kwargs) on remote engines, returning the result.
493 497
494 498 This method sets all of `apply`'s flags via this View's attributes.
495 499
496 500 Parameters
497 501 ----------
498 502
499 503 f : callable
500 504
501 505 args : list [default: empty]
502 506
503 507 kwargs : dict [default: empty]
504 508
505 509 targets : target list [default: self.targets]
506 510 where to run
507 511 block : bool [default: self.block]
508 512 whether to block
509 513 track : bool [default: self.track]
510 514 whether to ask zmq to track the message, for safe non-copying sends
511 515
512 516 Returns
513 517 -------
514 518
515 519 if self.block is False:
516 520 returns AsyncResult
517 521 else:
518 522 returns actual result of f(*args, **kwargs) on the engine(s)
519 523 This will be a list if self.targets is also a list (even length 1), or
520 524 the single result if self.targets is an integer engine id
521 525 """
522 526 args = [] if args is None else args
523 527 kwargs = {} if kwargs is None else kwargs
524 528 block = self.block if block is None else block
525 529 track = self.track if track is None else track
526 530 targets = self.targets if targets is None else targets
527 531
528 532 _idents = self.client._build_targets(targets)[0]
529 533 msg_ids = []
530 534 trackers = []
531 535 for ident in _idents:
532 536 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
533 537 ident=ident)
534 538 if track:
535 539 trackers.append(msg['tracker'])
536 540 msg_ids.append(msg['header']['msg_id'])
537 541 tracker = None if track is False else zmq.MessageTracker(*trackers)
538 542 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
539 543 if block:
540 544 try:
541 545 return ar.get()
542 546 except KeyboardInterrupt:
543 547 pass
544 548 return ar
545 549
546 550 @spin_after
547 551 def map(self, f, *sequences, **kwargs):
548 552 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
549 553
550 554 Parallel version of builtin `map`, using this View's `targets`.
551 555
552 556 There will be one task per target, so work will be chunked
553 557 if the sequences are longer than `targets`.
554 558
555 559 Results can be iterated as they are ready, but will become available in chunks.
556 560
557 561 Parameters
558 562 ----------
559 563
560 564 f : callable
561 565 function to be mapped
562 566 *sequences: one or more sequences of matching length
563 567 the sequences to be distributed and passed to `f`
564 568 block : bool
565 569 whether to wait for the result or not [default self.block]
566 570
567 571 Returns
568 572 -------
569 573
570 574 if block=False:
571 575 AsyncMapResult
572 576 An object like AsyncResult, but which reassembles the sequence of results
573 577 into a single list. AsyncMapResults can be iterated through before all
574 578 results are complete.
575 579 else:
576 580 list
577 581 the result of map(f,*sequences)
578 582 """
579 583
580 584 block = kwargs.pop('block', self.block)
581 585 for k in kwargs.keys():
582 586 if k not in ['block', 'track']:
583 587 raise TypeError("invalid keyword arg, %r"%k)
584 588
585 589 assert len(sequences) > 0, "must have some sequences to map onto!"
586 590 pf = ParallelFunction(self, f, block=block, **kwargs)
587 591 return pf.map(*sequences)
588 592
589 593 def execute(self, code, targets=None, block=None):
590 594 """Executes `code` on `targets` in blocking or nonblocking manner.
591 595
592 596 ``execute`` is always `bound` (affects engine namespace)
593 597
594 598 Parameters
595 599 ----------
596 600
597 601 code : str
598 602 the code string to be executed
599 603 block : bool
600 604 whether or not to wait until done to return
601 605 default: self.block
602 606 """
603 607 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
604 608
605 609 def run(self, filename, targets=None, block=None):
606 610 """Execute contents of `filename` on my engine(s).
607 611
608 612 This simply reads the contents of the file and calls `execute`.
609 613
610 614 Parameters
611 615 ----------
612 616
613 617 filename : str
614 618 The path to the file
615 619 targets : int/str/list of ints/strs
616 620 the engines on which to execute
617 621 default : all
618 622 block : bool
619 623 whether or not to wait until done
620 624 default: self.block
621 625
622 626 """
623 627 with open(filename, 'r') as f:
624 628 # add newline in case of trailing indented whitespace
625 629 # which will cause SyntaxError
626 630 code = f.read()+'\n'
627 631 return self.execute(code, block=block, targets=targets)
628 632
629 633 def update(self, ns):
630 634 """update remote namespace with dict `ns`
631 635
632 636 See `push` for details.
633 637 """
634 638 return self.push(ns, block=self.block, track=self.track)
635 639
636 640 def push(self, ns, targets=None, block=None, track=None):
637 641 """update remote namespace with dict `ns`
638 642
639 643 Parameters
640 644 ----------
641 645
642 646 ns : dict
643 647 dict of keys with which to update engine namespace(s)
644 648 block : bool [default : self.block]
645 649 whether to wait to be notified of engine receipt
646 650
647 651 """
648 652
649 653 block = block if block is not None else self.block
650 654 track = track if track is not None else self.track
651 655 targets = targets if targets is not None else self.targets
652 656 # applier = self.apply_sync if block else self.apply_async
653 657 if not isinstance(ns, dict):
654 658 raise TypeError("Must be a dict, not %s"%type(ns))
655 659 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
656 660
657 661 def get(self, key_s):
658 662 """get object(s) by `key_s` from remote namespace
659 663
660 664 see `pull` for details.
661 665 """
662 666 # block = block if block is not None else self.block
663 667 return self.pull(key_s, block=True)
664 668
665 669 def pull(self, names, targets=None, block=None):
666 670 """get object(s) by `name` from remote namespace
667 671
668 672 will return one object if it is a key.
669 673 can also take a list of keys, in which case it will return a list of objects.
670 674 """
671 675 block = block if block is not None else self.block
672 676 targets = targets if targets is not None else self.targets
673 677 applier = self.apply_sync if block else self.apply_async
674 678 if isinstance(names, basestring):
675 679 pass
676 680 elif isinstance(names, (list,tuple,set)):
677 681 for key in names:
678 682 if not isinstance(key, basestring):
679 683 raise TypeError("keys must be str, not type %r"%type(key))
680 684 else:
681 685 raise TypeError("names must be strs, not %r"%names)
682 686 return self._really_apply(util._pull, (names,), block=block, targets=targets)
683 687
684 688 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
685 689 """
686 690 Partition a Python sequence and send the partitions to a set of engines.
687 691 """
688 692 block = block if block is not None else self.block
689 693 track = track if track is not None else self.track
690 694 targets = targets if targets is not None else self.targets
691 695
692 696 mapObject = Map.dists[dist]()
693 697 nparts = len(targets)
694 698 msg_ids = []
695 699 trackers = []
696 700 for index, engineid in enumerate(targets):
697 701 partition = mapObject.getPartition(seq, index, nparts)
698 702 if flatten and len(partition) == 1:
699 703 ns = {key: partition[0]}
700 704 else:
701 705 ns = {key: partition}
702 706 r = self.push(ns, block=False, track=track, targets=engineid)
703 707 msg_ids.extend(r.msg_ids)
704 708 if track:
705 709 trackers.append(r._tracker)
706 710
707 711 if track:
708 712 tracker = zmq.MessageTracker(*trackers)
709 713 else:
710 714 tracker = None
711 715
712 716 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
713 717 if block:
714 718 r.wait()
715 719 else:
716 720 return r
717 721
718 722 @sync_results
719 723 @save_ids
720 724 def gather(self, key, dist='b', targets=None, block=None):
721 725 """
722 726 Gather a partitioned sequence on a set of engines as a single local seq.
723 727 """
724 728 block = block if block is not None else self.block
725 729 targets = targets if targets is not None else self.targets
726 730 mapObject = Map.dists[dist]()
727 731 msg_ids = []
728 732
729 733 for index, engineid in enumerate(targets):
730 734 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
731 735
732 736 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
733 737
734 738 if block:
735 739 try:
736 740 return r.get()
737 741 except KeyboardInterrupt:
738 742 pass
739 743 return r
740 744
741 745 def __getitem__(self, key):
742 746 return self.get(key)
743 747
744 748 def __setitem__(self,key, value):
745 749 self.update({key:value})
746 750
747 751 def clear(self, targets=None, block=False):
748 752 """Clear the remote namespaces on my engines."""
749 753 block = block if block is not None else self.block
750 754 targets = targets if targets is not None else self.targets
751 755 return self.client.clear(targets=targets, block=block)
752 756
753 757 def kill(self, targets=None, block=True):
754 758 """Kill my engines."""
755 759 block = block if block is not None else self.block
756 760 targets = targets if targets is not None else self.targets
757 761 return self.client.kill(targets=targets, block=block)
758 762
759 763 #----------------------------------------
760 764 # activate for %px,%autopx magics
761 765 #----------------------------------------
762 766 def activate(self):
763 767 """Make this `View` active for parallel magic commands.
764 768
765 769 IPython has a magic command syntax to work with `MultiEngineClient` objects.
766 770 In a given IPython session there is a single active one. While
767 771 there can be many `Views` created and used by the user,
768 772 there is only one active one. The active `View` is used whenever
769 773 the magic commands %px and %autopx are used.
770 774
771 775 The activate() method is called on a given `View` to make it
772 776 active. Once this has been done, the magic commands can be used.
773 777 """
774 778
775 779 try:
776 780 # This is injected into __builtins__.
777 781 ip = get_ipython()
778 782 except NameError:
779 783 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
780 784 else:
781 785 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
782 786 if pmagic is None:
783 787 ip.magic_load_ext('parallelmagic')
784 788 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
785 789
786 790 pmagic.active_view = self
787 791
788 792
789 793 @skip_doctest
790 794 class LoadBalancedView(View):
791 795 """An load-balancing View that only executes via the Task scheduler.
792 796
793 797 Load-balanced views can be created with the client's `view` method:
794 798
795 799 >>> v = client.load_balanced_view()
796 800
797 801 or targets can be specified, to restrict the potential destinations:
798 802
799 803 >>> v = client.load_balanced_view([1,3])
800 804
801 805 which would restrict load-balancing to engines 1 and 3.
802 806
803 807 """
804 808
805 809 follow=Any()
806 810 after=Any()
807 811 timeout=CFloat()
808 812 retries = Integer(0)
809 813
810 814 _task_scheme = Any()
811 815 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
812 816
813 817 def __init__(self, client=None, socket=None, **flags):
814 818 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
815 819 self._task_scheme=client._task_scheme
816 820
817 821 def _validate_dependency(self, dep):
818 822 """validate a dependency.
819 823
820 824 For use in `set_flags`.
821 825 """
822 826 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
823 827 return True
824 828 elif isinstance(dep, (list,set, tuple)):
825 829 for d in dep:
826 830 if not isinstance(d, (basestring, AsyncResult)):
827 831 return False
828 832 elif isinstance(dep, dict):
829 833 if set(dep.keys()) != set(Dependency().as_dict().keys()):
830 834 return False
831 835 if not isinstance(dep['msg_ids'], list):
832 836 return False
833 837 for d in dep['msg_ids']:
834 838 if not isinstance(d, basestring):
835 839 return False
836 840 else:
837 841 return False
838 842
839 843 return True
840 844
841 845 def _render_dependency(self, dep):
842 846 """helper for building jsonable dependencies from various input forms."""
843 847 if isinstance(dep, Dependency):
844 848 return dep.as_dict()
845 849 elif isinstance(dep, AsyncResult):
846 850 return dep.msg_ids
847 851 elif dep is None:
848 852 return []
849 853 else:
850 854 # pass to Dependency constructor
851 855 return list(Dependency(dep))
852 856
853 857 def set_flags(self, **kwargs):
854 858 """set my attribute flags by keyword.
855 859
856 860 A View is a wrapper for the Client's apply method, but with attributes
857 861 that specify keyword arguments, those attributes can be set by keyword
858 862 argument with this method.
859 863
860 864 Parameters
861 865 ----------
862 866
863 867 block : bool
864 868 whether to wait for results
865 869 track : bool
866 870 whether to create a MessageTracker to allow the user to
867 871 safely edit arrays and buffers after non-copying
868 872 sends.
869 873
870 874 after : Dependency or collection of msg_ids
871 875 Only for load-balanced execution (targets=None)
872 876 Specify a list of msg_ids as a time-based dependency.
873 877 This job will only be run *after* the dependencies
874 878 have been met.
875 879
876 880 follow : Dependency or collection of msg_ids
877 881 Only for load-balanced execution (targets=None)
878 882 Specify a list of msg_ids as a location-based dependency.
879 883 This job will only be run on an engine where this dependency
880 884 is met.
881 885
882 886 timeout : float/int or None
883 887 Only for load-balanced execution (targets=None)
884 888 Specify an amount of time (in seconds) for the scheduler to
885 889 wait for dependencies to be met before failing with a
886 890 DependencyTimeout.
887 891
888 892 retries : int
889 893 Number of times a task will be retried on failure.
890 894 """
891 895
892 896 super(LoadBalancedView, self).set_flags(**kwargs)
893 897 for name in ('follow', 'after'):
894 898 if name in kwargs:
895 899 value = kwargs[name]
896 900 if self._validate_dependency(value):
897 901 setattr(self, name, value)
898 902 else:
899 903 raise ValueError("Invalid dependency: %r"%value)
900 904 if 'timeout' in kwargs:
901 905 t = kwargs['timeout']
902 906 if not isinstance(t, (int, long, float, type(None))):
903 907 raise TypeError("Invalid type for timeout: %r"%type(t))
904 908 if t is not None:
905 909 if t < 0:
906 910 raise ValueError("Invalid timeout: %s"%t)
907 911 self.timeout = t
908 912
909 913 @sync_results
910 914 @save_ids
911 915 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
912 916 after=None, follow=None, timeout=None,
913 917 targets=None, retries=None):
914 918 """calls f(*args, **kwargs) on a remote engine, returning the result.
915 919
916 920 This method temporarily sets all of `apply`'s flags for a single call.
917 921
918 922 Parameters
919 923 ----------
920 924
921 925 f : callable
922 926
923 927 args : list [default: empty]
924 928
925 929 kwargs : dict [default: empty]
926 930
927 931 block : bool [default: self.block]
928 932 whether to block
929 933 track : bool [default: self.track]
930 934 whether to ask zmq to track the message, for safe non-copying sends
931 935
932 936 !!!!!! TODO: THE REST HERE !!!!
933 937
934 938 Returns
935 939 -------
936 940
937 941 if self.block is False:
938 942 returns AsyncResult
939 943 else:
940 944 returns actual result of f(*args, **kwargs) on the engine(s)
941 945 This will be a list if self.targets is also a list (even length 1), or
942 946 the single result if self.targets is an integer engine id
943 947 """
944 948
945 949 # validate whether we can run
946 950 if self._socket.closed:
947 951 msg = "Task farming is disabled"
948 952 if self._task_scheme == 'pure':
949 953 msg += " because the pure ZMQ scheduler cannot handle"
950 954 msg += " disappearing engines."
951 955 raise RuntimeError(msg)
952 956
953 957 if self._task_scheme == 'pure':
954 958 # pure zmq scheme doesn't support extra features
955 959 msg = "Pure ZMQ scheduler doesn't support the following flags:"
956 960 "follow, after, retries, targets, timeout"
957 961 if (follow or after or retries or targets or timeout):
958 962 # hard fail on Scheduler flags
959 963 raise RuntimeError(msg)
960 964 if isinstance(f, dependent):
961 965 # soft warn on functional dependencies
962 966 warnings.warn(msg, RuntimeWarning)
963 967
964 968 # build args
965 969 args = [] if args is None else args
966 970 kwargs = {} if kwargs is None else kwargs
967 971 block = self.block if block is None else block
968 972 track = self.track if track is None else track
969 973 after = self.after if after is None else after
970 974 retries = self.retries if retries is None else retries
971 975 follow = self.follow if follow is None else follow
972 976 timeout = self.timeout if timeout is None else timeout
973 977 targets = self.targets if targets is None else targets
974 978
975 979 if not isinstance(retries, int):
976 980 raise TypeError('retries must be int, not %r'%type(retries))
977 981
978 982 if targets is None:
979 983 idents = []
980 984 else:
981 985 idents = self.client._build_targets(targets)[0]
982 986 # ensure *not* bytes
983 987 idents = [ ident.decode() for ident in idents ]
984 988
985 989 after = self._render_dependency(after)
986 990 follow = self._render_dependency(follow)
987 991 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
988 992
989 993 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
990 994 subheader=subheader)
991 995 tracker = None if track is False else msg['tracker']
992 996
993 997 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
994 998
995 999 if block:
996 1000 try:
997 1001 return ar.get()
998 1002 except KeyboardInterrupt:
999 1003 pass
1000 1004 return ar
1001 1005
1002 1006 @spin_after
1003 1007 @save_ids
1004 1008 def map(self, f, *sequences, **kwargs):
1005 1009 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1006 1010
1007 1011 Parallel version of builtin `map`, load-balanced by this View.
1008 1012
1009 1013 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1010 1014
1011 1015 Each `chunksize` elements will be a separate task, and will be
1012 1016 load-balanced. This lets individual elements be available for iteration
1013 1017 as soon as they arrive.
1014 1018
1015 1019 Parameters
1016 1020 ----------
1017 1021
1018 1022 f : callable
1019 1023 function to be mapped
1020 1024 *sequences: one or more sequences of matching length
1021 1025 the sequences to be distributed and passed to `f`
1022 1026 block : bool [default self.block]
1023 1027 whether to wait for the result or not
1024 1028 track : bool
1025 1029 whether to create a MessageTracker to allow the user to
1026 1030 safely edit arrays and buffers after non-copying
1027 1031 sends.
1028 1032 chunksize : int [default 1]
1029 1033 how many elements should be in each task.
1030 1034 ordered : bool [default True]
1031 1035 Whether the results should be gathered as they arrive, or enforce
1032 1036 the order of submission.
1033 1037
1034 1038 Only applies when iterating through AsyncMapResult as results arrive.
1035 1039 Has no effect when block=True.
1036 1040
1037 1041 Returns
1038 1042 -------
1039 1043
1040 1044 if block=False:
1041 1045 AsyncMapResult
1042 1046 An object like AsyncResult, but which reassembles the sequence of results
1043 1047 into a single list. AsyncMapResults can be iterated through before all
1044 1048 results are complete.
1045 1049 else:
1046 1050 the result of map(f,*sequences)
1047 1051
1048 1052 """
1049 1053
1050 1054 # default
1051 1055 block = kwargs.get('block', self.block)
1052 1056 chunksize = kwargs.get('chunksize', 1)
1053 1057 ordered = kwargs.get('ordered', True)
1054 1058
1055 1059 keyset = set(kwargs.keys())
1056 1060 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1057 1061 if extra_keys:
1058 1062 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1059 1063
1060 1064 assert len(sequences) > 0, "must have some sequences to map onto!"
1061 1065
1062 1066 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1063 1067 return pf.map(*sequences)
1064 1068
1065 1069 __all__ = ['LoadBalancedView', 'DirectView']