##// END OF EJS Templates
remove leftover (
Jens H. Nielsen -
Show More
@@ -1,1048 +1,1048 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
28 28 from IPython.external.decorator import decorator
29 29
30 30 from IPython.parallel import util
31 31 from IPython.parallel.controller.dependency import Dependency, dependent
32 32
33 33 from . import map as Map
34 34 from .asyncresult import AsyncResult, AsyncMapResult
35 35 from .remotefunction import ParallelFunction, parallel, remote
36 36
37 37 #-----------------------------------------------------------------------------
38 38 # Decorators
39 39 #-----------------------------------------------------------------------------
40 40
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call.

    Every msg_id that the wrapped call appended to ``self.client.history`` is
    added to this view's own ``history`` list and ``outstanding`` set.  The
    bookkeeping lives in a ``finally`` clause, so messages that were submitted
    before `f` raised are still recorded.
    """
    n_previous = len(self.client.history)
    try:
        ret = f(self, *args, **kwargs)
    finally:
        # Slice from the previous length instead of using a negative count:
        # when nothing was sent, ``history[-0:]`` is the *whole* history and
        # would re-register every old message with this view.
        msg_ids = self.client.history[n_previous:]
        self.history.extend(msg_ids)
        self.outstanding.update(msg_ids)
    return ret
53 53
@decorator
def sync_results(f, self, *args, **kwargs):
    """Copy newly finished results from self.client into our results attribute.

    After calling `f`, any msg_id that is in ``self.outstanding`` but no
    longer in ``self.client.outstanding`` is treated as completed: it is
    dropped from ``self.outstanding`` and its entry in ``self.client.results``
    is stored in ``self.results``.  The return value of `f` is passed through.
    """
    ret = f(self, *args, **kwargs)
    pending = self.outstanding
    finished = pending.intersection(pending.difference(self.client.outstanding))
    self.outstanding = pending.difference(finished)
    client_results = self.client.results
    for done_id in finished:
        self.results[done_id] = client_results[done_id]
    return ret
64 64
@decorator
def spin_after(f, self, *args, **kwargs):
    """Run the wrapped method, then call ``self.spin()`` before returning its result.

    If the wrapped method raises, ``spin`` is not called.
    """
    outcome = f(self, *args, **kwargs)
    self.spin()
    return outcome
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # Classes
74 74 #-----------------------------------------------------------------------------
75 75
@skip_doctest
class View(HasTraits):
    """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.

    Don't use this class, use subclasses.

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        get_result, queue_status, purge_results, result_status

    control methods
        abort, shutdown

    """
    # flags
    block=Bool(False)
    track=Bool(True)
    targets = Any()

    # bookkeeping, maintained by the save_ids / sync_results decorators
    history=List()
    outstanding = Set()
    results = Dict()
    client = Instance('IPython.parallel.Client')

    _socket = Instance('zmq.Socket')
    # names accepted by set_flags / temp_flags
    _flag_names = List(['targets', 'block', 'track'])
    _targets = Any()
    _idents = Any()

    def __init__(self, client=None, socket=None, **flags):
        """Bind the view to `client` and `socket`, then apply `flags`.

        The initial `block` value is copied from ``client.block``; `flags`
        are applied afterwards through `set_flags` and may override it.
        """
        super(View, self).__init__(client=client, _socket=socket)
        self.block = client.block

        self.set_flags(**flags)

        assert not self.__class__ is View, "Don't use base View objects, use subclasses"


    def __repr__(self):
        """Return ``<ClassName targets>``, abbreviating long target lists."""
        strtargets = str(self.targets)
        if len(strtargets) > 16:
            strtargets = strtargets[:12]+'...]'
        return "<%s %s>"%(self.__class__.__name__, strtargets)

    def set_flags(self, **kwargs):
        """set my attribute flags by keyword.

        Views determine behavior with a few attributes (`block`, `track`, etc.).
        These attributes can be set all at once by name with this method.

        Parameters
        ----------

        block : bool
            whether to wait for results
        track : bool
            whether to create a MessageTracker to allow the user to
            safely edit after arrays and buffers during non-copying
            sends.

        Raises
        ------

        KeyError
            if a keyword is not one of `self._flag_names`
        """
        for name, value in kwargs.iteritems():
            if name not in self._flag_names:
                raise KeyError("Invalid name: %r"%name)
            else:
                setattr(self, name, value)

    @contextmanager
    def temp_flags(self, **kwargs):
        """temporarily set flags, for use in `with` statements.

        See set_flags for permanent setting of flags

        Examples
        --------

        >>> view.track=False
        ...
        >>> with view.temp_flags(track=True):
        ...    ar = view.apply(dostuff, my_big_array)
        ...    ar.tracker.wait() # wait for send to finish
        >>> view.track
        False

        """
        # preflight: save the current value of every flag
        saved_flags = {}
        for f in self._flag_names:
            saved_flags[f] = getattr(self, f)
        try:
            # Apply the temporaries inside the try block: set_flags can fail
            # after having already changed some attributes, and the saved
            # values must be restored in that case as well.
            self.set_flags(**kwargs)
            # yield to the with-statement block
            yield
        finally:
            # postflight: restore saved flags
            self.set_flags(**saved_flags)


    #----------------------------------------------------------------
    # apply
    #----------------------------------------------------------------

    @sync_results
    @save_ids
    def _really_apply(self, f, args, kwargs, block=None, **options):
        """wrapper for client.send_apply_message

        Subclasses must override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError("Implement in subclasses")

    def apply(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all apply flags via this View's attributes.

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs)

    def apply_async(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

        returns AsyncResult
        """
        return self._really_apply(f, args, kwargs, block=False)

    @spin_after
    def apply_sync(self, f, *args, **kwargs):
        """calls f(*args, **kwargs) on remote engines in a blocking manner,
         returning the result.

        returns: actual result of f(*args, **kwargs)
        """
        return self._really_apply(f, args, kwargs, block=True)

    #----------------------------------------------------------------
    # wrappers for client and control methods
    #----------------------------------------------------------------
    @sync_results
    def spin(self):
        """spin the client, and sync"""
        self.client.spin()

    @sync_results
    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
                ints are indices to self.history
                strs are msg_ids
                default: wait on every message in this view's history
        timeout : float
                a time in seconds, after which to give up.
                default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        if jobs is None:
            jobs = self.history
        return self.client.wait(jobs, timeout)

    def abort(self, jobs=None, targets=None, block=None):
        """Abort jobs on my engines.

        Parameters
        ----------

        jobs : None, str, list of strs, optional
            if None: abort all jobs.
            else: abort specific msg_id(s).
        targets : optional
            defaults to self.targets
        block : bool, optional
            defaults to self.block
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        return self.client.abort(jobs=jobs, targets=targets, block=block)

    def queue_status(self, targets=None, verbose=False):
        """Fetch the Queue status of my engines"""
        targets = targets if targets is not None else self.targets
        return self.client.queue_status(targets=targets, verbose=verbose)

    def purge_results(self, jobs=[], targets=[]):
        """Instruct the controller to forget specific results.

        `targets` of None or 'all' is replaced by self.targets; any other
        value (including the empty-list default) is passed to the client
        unchanged.  The default lists are never mutated here.
        """
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(jobs=jobs, targets=targets)

    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        `targets` of None or 'all' means self.targets; `block` defaults to
        self.block.
        """
        block = self.block if block is None else block
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)

    @spin_after
    def get_result(self, indices_or_msg_ids=None):
        """return one or more results, specified by history index or msg_id.

        Integer entries are looked up in this view's own history (the
        default, None, means the most recent entry); everything else is
        handed to the client as-is.

        See client.get_result for details.

        """

        if indices_or_msg_ids is None:
            indices_or_msg_ids = -1
        if isinstance(indices_or_msg_ids, int):
            indices_or_msg_ids = self.history[indices_or_msg_ids]
        elif isinstance(indices_or_msg_ids, (list,tuple,set)):
            # work on a list copy so the caller's collection is left untouched
            indices_or_msg_ids = list(indices_or_msg_ids)
            for i,index in enumerate(indices_or_msg_ids):
                if isinstance(index, int):
                    indices_or_msg_ids[i] = self.history[index]
        return self.client.get_result(indices_or_msg_ids)

    #-------------------------------------------------------------------
    # Map
    #-------------------------------------------------------------------

    def map(self, f, *sequences, **kwargs):
        """override in subclasses"""
        raise NotImplementedError

    def map_async(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=False)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_async doesn't take a `block` keyword argument.")
        kwargs['block'] = False
        return self.map(f,*sequences,**kwargs)

    def map_sync(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using this view's engines.

        This is equivalent to map(...block=True)

        See `self.map` for details.
        """
        if 'block' in kwargs:
            raise TypeError("map_sync doesn't take a `block` keyword argument.")
        kwargs['block'] = True
        return self.map(f,*sequences,**kwargs)

    def imap(self, f, *sequences, **kwargs):
        """Parallel version of `itertools.imap`.

        Returns an iterator over the result of `map_async`.

        See `self.map` for details.

        """

        return iter(self.map_async(f,*sequences, **kwargs))

    #-------------------------------------------------------------------
    # Decorators
    #-------------------------------------------------------------------

    def remote(self, block=True, **flags):
        """Decorator for making a RemoteFunction

        `block` of None falls back to self.block.
        """
        block = self.block if block is None else block
        return remote(self, block=block, **flags)

    def parallel(self, dist='b', block=None, **flags):
        """Decorator for making a ParallelFunction

        `block` of None falls back to self.block.
        """
        block = self.block if block is None else block
        return parallel(self, dist=dist, block=block, **flags)
366 366
@skip_doctest
class DirectView(View):
    """Direct Multiplexer View of one or more engines.

    These are created via indexed access to a client:

    >>> dv_1 = client[1]
    >>> dv_all = client[:]
    >>> dv_even = client[::2]
    >>> dv_some = client[1:3]

    This object provides dictionary access to engine namespaces:

    # push a=5:
    >>> dv['a'] = 5
    # pull 'foo':
    >>> dv['foo']

    """

    def __init__(self, client=None, socket=None, targets=None):
        """Create a view of `targets` that sends through `client` and `socket`."""
        super(DirectView, self).__init__(client=client, socket=socket, targets=targets)

    @property
    def importer(self):
        """sync_imports(local=True) as a property.

        See sync_imports for details.

        """
        return self.sync_imports(True)

    @contextmanager
    def sync_imports(self, local=True):
        """Context Manager for performing simultaneous local and remote imports.

        'import x as y' will *not* work.  The 'as y' part will simply be ignored.

        >>> with view.sync_imports():
        ...    from numpy import recarray
        importing recarray from numpy on engine(s)

        While the block runs, ``__builtin__.__import__`` is replaced by a
        wrapper that performs the import locally and also submits it to the
        engines with `apply_async`.  The original ``__import__`` is always
        restored on exit.  After a successful block, every remote result is
        fetched so that remote ImportErrors are raised here.

        With ``local=True`` an ImportError raised inside the block propagates
        to the caller; it is only suppressed when ``local`` is false.
        """
        import __builtin__
        local_import = __builtin__.__import__
        modules = set()
        results = []
        @util.interactive
        def remote_import(name, fromlist, level):
            """the function to be passed to apply, that actually performs the import
            on the engine, and loads up the user namespace.
            """
            import sys
            user_ns = globals()
            mod = __import__(name, fromlist=fromlist, level=level)
            if fromlist:
                for key in fromlist:
                    user_ns[key] = getattr(mod, key)
            else:
                user_ns[name] = sys.modules[name]

        def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
            """the drop-in replacement for __import__, that optionally imports
            locally as well.
            """
            # don't override nested imports
            save_import = __builtin__.__import__
            __builtin__.__import__ = local_import
            try:
                if imp.lock_held():
                    # this is a side-effect import, don't do it remotely, or even
                    # ignore the local effects
                    return local_import(name, globals, locals, fromlist, level)

                imp.acquire_lock()
                try:
                    if local:
                        mod = local_import(name, globals, locals, fromlist, level)
                    else:
                        raise NotImplementedError("remote-only imports not yet implemented")
                finally:
                    # release even when the import fails, otherwise the
                    # import lock would stay held by this thread
                    imp.release_lock()

                key = name+':'+','.join(fromlist or [])
                if level == -1 and key not in modules:
                    modules.add(key)
                    if fromlist:
                        print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
                    else:
                        print("importing %s on engine(s)"%name)
                    results.append(self.apply_async(remote_import, name, fromlist, level))

                return mod
            finally:
                # restore override on every exit path, so later imports in
                # the same with-block are still intercepted
                __builtin__.__import__ = save_import

        # override __import__
        __builtin__.__import__ = view_import
        try:
            # enter the block
            yield
        except ImportError:
            if local:
                # a failed local import is a real error for the caller
                raise
            # ignore import errors if not doing local imports
        finally:
            # always restore __import__
            __builtin__.__import__ = local_import

        for r in results:
            # raise possible remote ImportErrors here
            r.get()


    @sync_results
    @save_ids
    def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
        """calls f(*args, **kwargs) on remote engines, returning the result.

        This method sets all of `apply`'s flags via this View's attributes.

        Parameters
        ----------

        f : callable

        args : list [default: empty]

        kwargs : dict [default: empty]

        targets : target list [default: self.targets]
            where to run
        block : bool [default: self.block]
            whether to block
        track : bool [default: self.track]
            whether to ask zmq to track the message, for safe non-copying sends

        Returns
        -------

        if self.block is False:
            returns AsyncResult
        else:
            returns actual result of f(*args, **kwargs) on the engine(s)
            This will be a list if self.targets is also a list (even length 1), or
            the single result if self.targets is an integer engine id

        If a blocking wait is interrupted with KeyboardInterrupt, the
        AsyncResult is returned instead of the result.
        """
        args = [] if args is None else args
        kwargs = {} if kwargs is None else kwargs
        block = self.block if block is None else block
        track = self.track if track is None else track
        targets = self.targets if targets is None else targets

        # first item returned by _build_targets: one ident per destination
        _idents = self.client._build_targets(targets)[0]
        msg_ids = []
        trackers = []
        for ident in _idents:
            # one message per destination, all sent over this view's socket
            msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
                                    ident=ident)
            if track:
                trackers.append(msg['tracker'])
            msg_ids.append(msg['header']['msg_id'])
        tracker = None if track is False else zmq.MessageTracker(*trackers)
        ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # deliberate: fall through and hand back the AsyncResult
                pass
        return ar

    @spin_after
    def map(self, f, *sequences, **kwargs):
        """view.map(f, *sequences, block=self.block) => list|AsyncMapResult

        Parallel version of builtin `map`, using this View's `targets`.

        There will be one task per target, so work will be chunked
        if the sequences are longer than `targets`.

        Results can be iterated as they are ready, but will become available in chunks.

        Parameters
        ----------

        f : callable
            function to be mapped
        *sequences: one or more sequences of matching length
            the sequences to be distributed and passed to `f`
        block : bool
            whether to wait for the result or not [default self.block]

        Returns
        -------

        if block=False:
            AsyncMapResult
                An object like AsyncResult, but which reassembles the sequence of results
                into a single list. AsyncMapResults can be iterated through before all
                results are complete.
        else:
            list
                the result of map(f,*sequences)

        Raises
        ------

        TypeError
            for any keyword other than `block` and `track`
        """

        block = kwargs.pop('block', self.block)
        for k in kwargs.keys():
            if k not in ['block', 'track']:
                raise TypeError("invalid keyword arg, %r"%k)

        assert len(sequences) > 0, "must have some sequences to map onto!"
        pf = ParallelFunction(self, f, block=block, **kwargs)
        return pf.map(*sequences)

    def execute(self, code, targets=None, block=None):
        """Executes `code` on `targets` in blocking or nonblocking manner.

        ``execute`` is always `bound` (affects engine namespace)

        Parameters
        ----------

        code : str
                the code string to be executed
        targets : optional
                where to run; default: self.targets
        block : bool
                whether or not to wait until done to return
                default: self.block
        """
        return self._really_apply(util._execute, args=(code,), block=block, targets=targets)

    def run(self, filename, targets=None, block=None):
        """Execute contents of `filename` on my engine(s).

        This simply reads the contents of the file and calls `execute`.

        Parameters
        ----------

        filename : str
                The path to the file
        targets : int/str/list of ints/strs
                the engines on which to execute
                default : all
        block : bool
                whether or not to wait until done
                default: self.block

        """
        with open(filename, 'r') as f:
            # add newline in case of trailing indented whitespace
            # which will cause SyntaxError
            code = f.read()+'\n'
        return self.execute(code, block=block, targets=targets)

    def update(self, ns):
        """update remote namespace with dict `ns`

        See `push` for details.
        """
        return self.push(ns, block=self.block, track=self.track)

    def push(self, ns, targets=None, block=None, track=None):
        """update remote namespace with dict `ns`

        Parameters
        ----------

        ns : dict
            dict of keys with which to update engine namespace(s)
        targets : optional [default : self.targets]
            where to push
        block : bool [default : self.block]
            whether to wait to be notified of engine receipt
        track : bool [default : self.track]
            passed on to `_really_apply`

        Raises
        ------

        TypeError
            if `ns` is not a dict
        """

        block = block if block is not None else self.block
        track = track if track is not None else self.track
        targets = targets if targets is not None else self.targets
        if not isinstance(ns, dict):
            raise TypeError("Must be a dict, not %s"%type(ns))
        return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)

    def get(self, key_s):
        """get object(s) by `key_s` from remote namespace

        Always blocks.

        see `pull` for details.
        """
        return self.pull(key_s, block=True)

    def pull(self, names, targets=None, block=None):
        """get object(s) by `name` from remote namespace

        will return one object if it is a key.
        can also take a list of keys, in which case it will return a list of objects.

        Raises
        ------

        TypeError
            if `names` is neither a string nor a list/tuple/set of strings
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        if isinstance(names, basestring):
            pass
        elif isinstance(names, (list,tuple,set)):
            for key in names:
                if not isinstance(key, basestring):
                    raise TypeError("keys must be str, not type %r"%type(key))
        else:
            raise TypeError("names must be strs, not %r"%names)
        return self._really_apply(util._pull, (names,), block=block, targets=targets)

    def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
        """
        Partition a Python sequence and send the partitions to a set of engines.

        `seq` is split into ``len(targets)`` partitions using the map class
        registered under `dist`; partition ``i`` is pushed as `key` to the
        ``i``-th target.  With `flatten`, a partition of length one is pushed
        as its single element instead of as a sequence.

        Returns the combined AsyncResult when not blocking; when blocking,
        waits for the pushes and returns None.
        """
        block = block if block is not None else self.block
        track = track if track is not None else self.track
        targets = targets if targets is not None else self.targets

        mapObject = Map.dists[dist]()
        nparts = len(targets)
        msg_ids = []
        trackers = []
        for index, engineid in enumerate(targets):
            partition = mapObject.getPartition(seq, index, nparts)
            if flatten and len(partition) == 1:
                ns = {key: partition[0]}
            else:
                ns = {key: partition}
            # always push asynchronously; blocking is handled once, below
            r = self.push(ns, block=False, track=track, targets=engineid)
            msg_ids.extend(r.msg_ids)
            if track:
                trackers.append(r._tracker)

        if track:
            tracker = zmq.MessageTracker(*trackers)
        else:
            tracker = None

        r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
        if block:
            r.wait()
        else:
            return r

    @sync_results
    @save_ids
    def gather(self, key, dist='b', targets=None, block=None):
        """
        Gather a partitioned sequence on a set of engines as a single local seq.

        `key` is pulled from each target separately and the pieces are
        combined by an AsyncMapResult using the map class registered under
        `dist`.  Returns the gathered result when blocking, otherwise (or if
        the wait is interrupted with KeyboardInterrupt) the AsyncMapResult.
        """
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        mapObject = Map.dists[dist]()
        msg_ids = []

        for engineid in targets:
            msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)

        r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')

        if block:
            try:
                return r.get()
            except KeyboardInterrupt:
                # deliberate: fall through and hand back the AsyncMapResult
                pass
        return r

    def __getitem__(self, key):
        """``view[key]``: blocking pull of `key` (see `get`)."""
        return self.get(key)

    def __setitem__(self,key, value):
        """``view[key] = value``: push a single name (see `update`)."""
        self.update({key:value})

    def clear(self, targets=None, block=False):
        """Clear the remote namespaces on my engines."""
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        return self.client.clear(targets=targets, block=block)

    def kill(self, targets=None, block=True):
        """Kill my engines."""
        block = block if block is not None else self.block
        targets = targets if targets is not None else self.targets
        return self.client.kill(targets=targets, block=block)

    #----------------------------------------
    # activate for %px,%autopx magics
    #----------------------------------------
    def activate(self):
        """Make this `View` active for parallel magic commands.

        IPython has a magic command syntax to work with `MultiEngineClient` objects.
        In a given IPython session there is a single active one.  While
        there can be many `Views` created and used by the user,
        there is only one active one.  The active `View` is used whenever
        the magic commands %px and %autopx are used.

        The activate() method is called on a given `View` to make it
        active.  Once this has been done, the magic commands can be used.

        Outside of IPython (no `get_ipython` available) this only prints a
        notice.
        """

        try:
            # This is injected into __builtins__.
            ip = get_ipython()
        except NameError:
            print("The IPython parallel magics (%result, %px, %autopx) only work within IPython.")
        else:
            pmagic = ip.plugin_manager.get_plugin('parallelmagic')
            if pmagic is None:
                # the plugin is not loaded yet: load the extension and retry
                ip.magic_load_ext('parallelmagic')
                pmagic = ip.plugin_manager.get_plugin('parallelmagic')

            pmagic.active_view = self
777 777
778 778
779 779 @skip_doctest
780 780 class LoadBalancedView(View):
    """A load-balancing View that only executes via the Task scheduler.
782 782
783 783 Load-balanced views can be created with the client's `view` method:
784 784
785 785 >>> v = client.load_balanced_view()
786 786
787 787 or targets can be specified, to restrict the potential destinations:
788 788
    >>> v = client.load_balanced_view([1,3])
790 790
791 791 which would restrict loadbalancing to between engines 1 and 3.
792 792
793 793 """
794 794
795 795 follow=Any()
796 796 after=Any()
797 797 timeout=CFloat()
798 798 retries = CInt(0)
799 799
800 800 _task_scheme = Any()
801 801 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
802 802
def __init__(self, client=None, socket=None, **flags):
    """Initialise the view, then remember the client's task scheme.

    `client`, `socket` and `flags` are forwarded unchanged to the parent
    constructor; afterwards ``client._task_scheme`` is copied onto the view.
    """
    super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
    scheme = client._task_scheme
    self._task_scheme = scheme
806 806
def _validate_dependency(self, dep):
    """Return True if `dep` is an acceptable dependency, False otherwise.

    For use in `set_flags`.

    Accepted forms:

    * None, a string, an AsyncResult or a Dependency
    * a list, set or tuple whose items are all strings or AsyncResults
    * a dict with exactly the keys of ``Dependency().as_dict()`` whose
      'msg_ids' entry is a list of strings
    """
    if dep is None:
        return True
    if isinstance(dep, (basestring, AsyncResult, Dependency)):
        return True

    if isinstance(dep, (list,set, tuple)):
        return all(isinstance(item, (basestring, AsyncResult)) for item in dep)

    if isinstance(dep, dict):
        if set(dep.keys()) != set(Dependency().as_dict().keys()):
            return False
        ids = dep['msg_ids']
        if not isinstance(ids, list):
            return False
        return all(isinstance(msg_id, basestring) for msg_id in ids)

    # any other type is rejected
    return False
830 830
def _render_dependency(self, dep):
    """Convert any accepted dependency form into a jsonable value.

    None becomes an empty list, a Dependency its ``as_dict()``, an
    AsyncResult its ``msg_ids``; anything else is passed to the Dependency
    constructor and returned as a list.
    """
    if dep is None:
        return []
    if isinstance(dep, Dependency):
        return dep.as_dict()
    if isinstance(dep, AsyncResult):
        return dep.msg_ids
    # let the Dependency constructor deal with every other input form
    rendered = Dependency(dep)
    return list(rendered)
842 842
def set_flags(self, **kwargs):
    """set my attribute flags by keyword.

    A View is a wrapper for the Client's apply method, but with attributes
    that specify keyword arguments, those attributes can be set by keyword
    argument with this method.

    `follow`, `after` and `timeout` are validated *before* anything is
    assigned, so a rejected value is never stored on the view.

    Parameters
    ----------

    block : bool
        whether to wait for results
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit after arrays and buffers during non-copying
        sends.

    after : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a time-based dependency.
        This job will only be run *after* the dependencies
        have been met.

    follow : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a location-based dependency.
        This job will only be run on an engine where this dependency
        is met.

    timeout : float/int or None
        Only for load-balanced execution (targets=None)
        Specify an amount of time (in seconds) for the scheduler to
        wait for dependencies to be met before failing with a
        DependencyTimeout.

    retries : int
        Number of times a task will be retried on failure.

    Raises
    ------

    ValueError
        for an invalid `follow`/`after` dependency or a negative `timeout`
    TypeError
        if `timeout` is not an int, long, float or None
    KeyError
        (from the base class) for a name that is not a known flag
    """

    # Check the dependency flags up front.
    for name in ('follow', 'after'):
        if name in kwargs:
            value = kwargs[name]
            if not self._validate_dependency(value):
                raise ValueError("Invalid dependency: %r"%value)
    if 'timeout' in kwargs:
        t = kwargs['timeout']
        if not isinstance(t, (int, long, float, type(None))):
            raise TypeError("Invalid type for timeout: %r"%type(t))
        if t is not None:
            if t < 0:
                raise ValueError("Invalid timeout: %s"%t)

    # Everything passed validation: the base class checks the flag names
    # and performs the actual assignments.
    super(LoadBalancedView, self).set_flags(**kwargs)
898 898
@sync_results
@save_ids
def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
                  after=None, follow=None, timeout=None,
                  targets=None, retries=None):
    """Submit f(*args, **kwargs) as a load-balanced task.

    This method temporarily sets all of `apply`'s flags for a single call.
    Any flag left as None falls back to the attribute of the same name
    on this view.

    Parameters
    ----------

    f : callable

    args : list [default: empty]

    kwargs : dict [default: empty]

    block : bool [default: self.block]
        whether to block
    track : bool [default: self.track]
        whether to ask zmq to track the message, for safe non-copying sends
    after : Dependency or collection of msg_ids [default: self.after]
        time-based dependency; the task will only be run *after* the
        dependencies have been met.
    follow : Dependency or collection of msg_ids [default: self.follow]
        location-based dependency; the task will only be run on an engine
        where this dependency is met.
    timeout : float/int or None [default: self.timeout]
        time (in seconds) for the scheduler to wait for dependencies to be
        met before failing with a DependencyTimeout.
    targets : [default: self.targets]
        resolved with `self.client._build_targets` and sent to the
        scheduler as a list of (non-bytes) engine identities; None sends
        an empty list.
    retries : int [default: self.retries]
        number of times a task will be retried on failure.

    Returns
    -------

    if block is False:
        returns AsyncResult
    else:
        returns the result of `AsyncResult.get()`; if waiting is
        interrupted by KeyboardInterrupt, the AsyncResult is returned
        instead so the caller can still retrieve the result later.

    Raises
    ------

    RuntimeError
        if the task socket is closed, or if the 'pure' scheduler scheme is
        in use and any of follow/after/retries/targets/timeout was passed.
    TypeError
        if `retries` (after applying defaults) is not an int.
    """

    # validate whether we can run
    if self._socket.closed:
        msg = "Task farming is disabled"
        if self._task_scheme == 'pure':
            msg += " because the pure ZMQ scheduler cannot handle"
            msg += " disappearing engines."
        raise RuntimeError(msg)

    if self._task_scheme == 'pure':
        # pure zmq scheme doesn't support extra features
        # NOTE: the parentheses matter -- without them the second literal
        # is a separate no-op statement and the flag list is dropped.
        msg = ("Pure ZMQ scheduler doesn't support the following flags: "
               "follow, after, retries, targets, timeout")
        if (follow or after or retries or targets or timeout):
            # hard fail on Scheduler flags
            raise RuntimeError(msg)
        if isinstance(f, dependent):
            # soft warn on functional dependencies
            warnings.warn(msg, RuntimeWarning)

    # build args: explicit arguments win, view attributes are the defaults
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs
    block = self.block if block is None else block
    track = self.track if track is None else track
    after = self.after if after is None else after
    retries = self.retries if retries is None else retries
    follow = self.follow if follow is None else follow
    timeout = self.timeout if timeout is None else timeout
    targets = self.targets if targets is None else targets

    if not isinstance(retries, int):
        raise TypeError('retries must be int, not %r'%type(retries))

    if targets is None:
        idents = []
    else:
        idents = self.client._build_targets(targets)[0]
        # ensure *not* bytes
        idents = [ ident.decode() for ident in idents ]

    after = self._render_dependency(after)
    follow = self._render_dependency(follow)
    subheader = dict(after=after, follow=follow, timeout=timeout,
                     targets=idents, retries=retries)

    msg = self.client.send_apply_message(self._socket, f, args, kwargs,
                                         track=track, subheader=subheader)
    # only an explicit track=False suppresses the tracker
    tracker = None if track is False else msg['tracker']

    ar = AsyncResult(self.client, msg['header']['msg_id'],
                     fname=f.__name__, targets=None, tracker=tracker)

    if block:
        try:
            return ar.get()
        except KeyboardInterrupt:
            # deliberate: fall through and hand back the AsyncResult
            pass
    return ar
991 991
@spin_after
@save_ids
def map(self, f, *sequences, **kwargs):
    """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult

    Parallel version of builtin `map`, load-balanced by this View.

    `block` and `chunksize` can be specified by keyword only.

    Each `chunksize` elements will be a separate task, and will be
    load-balanced. This lets individual elements be available for iteration
    as soon as they arrive.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool
        whether to wait for the result or not [default self.block]
    track : bool
        accepted so that callers following the documented signature keep
        working, but it is not forwarded to the ParallelFunction by this
        method.
    chunksize : int
        how many elements should be in each task [default 1]

    Returns
    -------

    if block=False:
        AsyncMapResult
            An object like AsyncResult, but which reassembles the sequence of results
            into a single list. AsyncMapResults can be iterated through before all
            results are complete.
    else:
        the result of map(f,*sequences)

    Raises
    ------

    TypeError
        if any keyword other than `block`, `chunksize` or `track` is given.
    AssertionError
        if no sequences are given.
    """

    # default
    block = kwargs.get('block', self.block)
    chunksize = kwargs.get('chunksize', 1)

    # set.difference returns the leftover keys; difference_update mutates
    # in place and returns None, which made this check a silent no-op.
    allowed_keys = set(['block', 'chunksize', 'track'])
    extra_keys = set(kwargs).difference(allowed_keys)
    if extra_keys:
        raise TypeError("Invalid kwargs: %s"%sorted(extra_keys))

    assert len(sequences) > 0, "must have some sequences to map onto!"

    pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
    return pf.map(*sequences)
1047 1047
# Public names exported by this module on `from ... import *`.
1048 1048 __all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now