##// END OF EJS Templates
add DirectView.use_dill
MinRK -
Show More
@@ -1,1119 +1,1130 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import imp
20 20 import sys
21 21 import warnings
22 22 from contextlib import contextmanager
23 23 from types import ModuleType
24 24
25 25 import zmq
26 26
27 27 from IPython.testing.skipdoctest import skip_doctest
28 from IPython.utils import pickleutil
28 29 from IPython.utils.traitlets import (
29 30 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
30 31 )
31 32 from IPython.external.decorator import decorator
32 33
33 34 from IPython.parallel import util
34 35 from IPython.parallel.controller.dependency import Dependency, dependent
35 36 from IPython.utils.py3compat import string_types, iteritems, PY3
36 37
37 38 from . import map as Map
38 39 from .asyncresult import AsyncResult, AsyncMapResult
39 40 from .remotefunction import ParallelFunction, parallel, remote, getname
40 41
41 42 #-----------------------------------------------------------------------------
42 43 # Decorators
43 44 #-----------------------------------------------------------------------------
44 45
45 46 @decorator
46 47 def save_ids(f, self, *args, **kwargs):
47 48 """Keep our history and outstanding attributes up to date after a method call."""
48 49 n_previous = len(self.client.history)
49 50 try:
50 51 ret = f(self, *args, **kwargs)
51 52 finally:
52 53 nmsgs = len(self.client.history) - n_previous
53 54 msg_ids = self.client.history[-nmsgs:]
54 55 self.history.extend(msg_ids)
55 56 self.outstanding.update(msg_ids)
56 57 return ret
57 58
58 59 @decorator
59 60 def sync_results(f, self, *args, **kwargs):
60 61 """sync relevant results from self.client to our results attribute."""
61 62 if self._in_sync_results:
62 63 return f(self, *args, **kwargs)
63 64 self._in_sync_results = True
64 65 try:
65 66 ret = f(self, *args, **kwargs)
66 67 finally:
67 68 self._in_sync_results = False
68 69 self._sync_results()
69 70 return ret
70 71
71 72 @decorator
72 73 def spin_after(f, self, *args, **kwargs):
73 74 """call spin after the method."""
74 75 ret = f(self, *args, **kwargs)
75 76 self.spin()
76 77 return ret
77 78
78 79 #-----------------------------------------------------------------------------
79 80 # Classes
80 81 #-----------------------------------------------------------------------------
81 82
82 83 @skip_doctest
83 84 class View(HasTraits):
84 85 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
85 86
86 87 Don't use this class, use subclasses.
87 88
88 89 Methods
89 90 -------
90 91
91 92 spin
92 93 flushes incoming results and registration state changes
93 94 control methods spin, and requesting `ids` also ensures up to date
94 95
95 96 wait
96 97 wait on one or more msg_ids
97 98
98 99 execution methods
99 100 apply
100 101 legacy: execute, run
101 102
102 103 data movement
103 104 push, pull, scatter, gather
104 105
105 106 query methods
106 107 get_result, queue_status, purge_results, result_status
107 108
108 109 control methods
109 110 abort, shutdown
110 111
111 112 """
112 113 # flags
113 114 block=Bool(False)
114 115 track=Bool(True)
115 116 targets = Any()
116 117
117 118 history=List()
118 119 outstanding = Set()
119 120 results = Dict()
120 121 client = Instance('IPython.parallel.Client')
121 122
122 123 _socket = Instance('zmq.Socket')
123 124 _flag_names = List(['targets', 'block', 'track'])
124 125 _in_sync_results = Bool(False)
125 126 _targets = Any()
126 127 _idents = Any()
127 128
128 129 def __init__(self, client=None, socket=None, **flags):
129 130 super(View, self).__init__(client=client, _socket=socket)
130 131 self.results = client.results
131 132 self.block = client.block
132 133
133 134 self.set_flags(**flags)
134 135
135 136 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
136 137
137 138 def __repr__(self):
138 139 strtargets = str(self.targets)
139 140 if len(strtargets) > 16:
140 141 strtargets = strtargets[:12]+'...]'
141 142 return "<%s %s>"%(self.__class__.__name__, strtargets)
142 143
143 144 def __len__(self):
144 145 if isinstance(self.targets, list):
145 146 return len(self.targets)
146 147 elif isinstance(self.targets, int):
147 148 return 1
148 149 else:
149 150 return len(self.client)
150 151
151 152 def set_flags(self, **kwargs):
152 153 """set my attribute flags by keyword.
153 154
154 155 Views determine behavior with a few attributes (`block`, `track`, etc.).
155 156 These attributes can be set all at once by name with this method.
156 157
157 158 Parameters
158 159 ----------
159 160
160 161 block : bool
161 162 whether to wait for results
162 163 track : bool
163 164 whether to create a MessageTracker to allow the user to
164 165 safely edit after arrays and buffers during non-copying
165 166 sends.
166 167 """
167 168 for name, value in iteritems(kwargs):
168 169 if name not in self._flag_names:
169 170 raise KeyError("Invalid name: %r"%name)
170 171 else:
171 172 setattr(self, name, value)
172 173
173 174 @contextmanager
174 175 def temp_flags(self, **kwargs):
175 176 """temporarily set flags, for use in `with` statements.
176 177
177 178 See set_flags for permanent setting of flags
178 179
179 180 Examples
180 181 --------
181 182
182 183 >>> view.track=False
183 184 ...
184 185 >>> with view.temp_flags(track=True):
185 186 ... ar = view.apply(dostuff, my_big_array)
186 187 ... ar.tracker.wait() # wait for send to finish
187 188 >>> view.track
188 189 False
189 190
190 191 """
191 192 # preflight: save flags, and set temporaries
192 193 saved_flags = {}
193 194 for f in self._flag_names:
194 195 saved_flags[f] = getattr(self, f)
195 196 self.set_flags(**kwargs)
196 197 # yield to the with-statement block
197 198 try:
198 199 yield
199 200 finally:
200 201 # postflight: restore saved flags
201 202 self.set_flags(**saved_flags)
202 203
203 204
204 205 #----------------------------------------------------------------
205 206 # apply
206 207 #----------------------------------------------------------------
207 208
208 209 def _sync_results(self):
209 210 """to be called by @sync_results decorator
210 211
211 212 after submitting any tasks.
212 213 """
213 214 delta = self.outstanding.difference(self.client.outstanding)
214 215 completed = self.outstanding.intersection(delta)
215 216 self.outstanding = self.outstanding.difference(completed)
216 217
217 218 @sync_results
218 219 @save_ids
219 220 def _really_apply(self, f, args, kwargs, block=None, **options):
220 221 """wrapper for client.send_apply_request"""
221 222 raise NotImplementedError("Implement in subclasses")
222 223
223 224 def apply(self, f, *args, **kwargs):
224 225 """calls f(*args, **kwargs) on remote engines, returning the result.
225 226
226 227 This method sets all apply flags via this View's attributes.
227 228
228 229 if self.block is False:
229 230 returns AsyncResult
230 231 else:
231 232 returns actual result of f(*args, **kwargs)
232 233 """
233 234 return self._really_apply(f, args, kwargs)
234 235
235 236 def apply_async(self, f, *args, **kwargs):
236 237 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
237 238
238 239 returns AsyncResult
239 240 """
240 241 return self._really_apply(f, args, kwargs, block=False)
241 242
242 243 @spin_after
243 244 def apply_sync(self, f, *args, **kwargs):
244 245 """calls f(*args, **kwargs) on remote engines in a blocking manner,
245 246 returning the result.
246 247
247 248 returns: actual result of f(*args, **kwargs)
248 249 """
249 250 return self._really_apply(f, args, kwargs, block=True)
250 251
251 252 #----------------------------------------------------------------
252 253 # wrappers for client and control methods
253 254 #----------------------------------------------------------------
254 255 @sync_results
255 256 def spin(self):
256 257 """spin the client, and sync"""
257 258 self.client.spin()
258 259
259 260 @sync_results
260 261 def wait(self, jobs=None, timeout=-1):
261 262 """waits on one or more `jobs`, for up to `timeout` seconds.
262 263
263 264 Parameters
264 265 ----------
265 266
266 267 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
267 268 ints are indices to self.history
268 269 strs are msg_ids
269 270 default: wait on all outstanding messages
270 271 timeout : float
271 272 a time in seconds, after which to give up.
272 273 default is -1, which means no timeout
273 274
274 275 Returns
275 276 -------
276 277
277 278 True : when all msg_ids are done
278 279 False : timeout reached, some msg_ids still outstanding
279 280 """
280 281 if jobs is None:
281 282 jobs = self.history
282 283 return self.client.wait(jobs, timeout)
283 284
284 285 def abort(self, jobs=None, targets=None, block=None):
285 286 """Abort jobs on my engines.
286 287
287 288 Parameters
288 289 ----------
289 290
290 291 jobs : None, str, list of strs, optional
291 292 if None: abort all jobs.
292 293 else: abort specific msg_id(s).
293 294 """
294 295 block = block if block is not None else self.block
295 296 targets = targets if targets is not None else self.targets
296 297 jobs = jobs if jobs is not None else list(self.outstanding)
297 298
298 299 return self.client.abort(jobs=jobs, targets=targets, block=block)
299 300
300 301 def queue_status(self, targets=None, verbose=False):
301 302 """Fetch the Queue status of my engines"""
302 303 targets = targets if targets is not None else self.targets
303 304 return self.client.queue_status(targets=targets, verbose=verbose)
304 305
305 306 def purge_results(self, jobs=[], targets=[]):
306 307 """Instruct the controller to forget specific results."""
307 308 if targets is None or targets == 'all':
308 309 targets = self.targets
309 310 return self.client.purge_results(jobs=jobs, targets=targets)
310 311
311 312 def shutdown(self, targets=None, restart=False, hub=False, block=None):
312 313 """Terminates one or more engine processes, optionally including the hub.
313 314 """
314 315 block = self.block if block is None else block
315 316 if targets is None or targets == 'all':
316 317 targets = self.targets
317 318 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
318 319
319 320 @spin_after
320 321 def get_result(self, indices_or_msg_ids=None):
321 322 """return one or more results, specified by history index or msg_id.
322 323
323 324 See client.get_result for details.
324 325
325 326 """
326 327
327 328 if indices_or_msg_ids is None:
328 329 indices_or_msg_ids = -1
329 330 if isinstance(indices_or_msg_ids, int):
330 331 indices_or_msg_ids = self.history[indices_or_msg_ids]
331 332 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
332 333 indices_or_msg_ids = list(indices_or_msg_ids)
333 334 for i,index in enumerate(indices_or_msg_ids):
334 335 if isinstance(index, int):
335 336 indices_or_msg_ids[i] = self.history[index]
336 337 return self.client.get_result(indices_or_msg_ids)
337 338
338 339 #-------------------------------------------------------------------
339 340 # Map
340 341 #-------------------------------------------------------------------
341 342
342 343 @sync_results
343 344 def map(self, f, *sequences, **kwargs):
344 345 """override in subclasses"""
345 346 raise NotImplementedError
346 347
347 348 def map_async(self, f, *sequences, **kwargs):
348 349 """Parallel version of builtin `map`, using this view's engines.
349 350
350 351 This is equivalent to map(...block=False)
351 352
352 353 See `self.map` for details.
353 354 """
354 355 if 'block' in kwargs:
355 356 raise TypeError("map_async doesn't take a `block` keyword argument.")
356 357 kwargs['block'] = False
357 358 return self.map(f,*sequences,**kwargs)
358 359
359 360 def map_sync(self, f, *sequences, **kwargs):
360 361 """Parallel version of builtin `map`, using this view's engines.
361 362
362 363 This is equivalent to map(...block=True)
363 364
364 365 See `self.map` for details.
365 366 """
366 367 if 'block' in kwargs:
367 368 raise TypeError("map_sync doesn't take a `block` keyword argument.")
368 369 kwargs['block'] = True
369 370 return self.map(f,*sequences,**kwargs)
370 371
371 372 def imap(self, f, *sequences, **kwargs):
372 373 """Parallel version of `itertools.imap`.
373 374
374 375 See `self.map` for details.
375 376
376 377 """
377 378
378 379 return iter(self.map_async(f,*sequences, **kwargs))
379 380
380 381 #-------------------------------------------------------------------
381 382 # Decorators
382 383 #-------------------------------------------------------------------
383 384
384 385 def remote(self, block=None, **flags):
385 386 """Decorator for making a RemoteFunction"""
386 387 block = self.block if block is None else block
387 388 return remote(self, block=block, **flags)
388 389
389 390 def parallel(self, dist='b', block=None, **flags):
390 391 """Decorator for making a ParallelFunction"""
391 392 block = self.block if block is None else block
392 393 return parallel(self, dist=dist, block=block, **flags)
393 394
394 395 @skip_doctest
395 396 class DirectView(View):
396 397 """Direct Multiplexer View of one or more engines.
397 398
398 399 These are created via indexed access to a client:
399 400
400 401 >>> dv_1 = client[1]
401 402 >>> dv_all = client[:]
402 403 >>> dv_even = client[::2]
403 404 >>> dv_some = client[1:3]
404 405
405 406 This object provides dictionary access to engine namespaces:
406 407
407 408 # push a=5:
408 409 >>> dv['a'] = 5
409 410 # pull 'foo':
410 411 >>> db['foo']
411 412
412 413 """
413 414
414 415 def __init__(self, client=None, socket=None, targets=None):
415 416 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
416 417
417 418 @property
418 419 def importer(self):
419 420 """sync_imports(local=True) as a property.
420 421
421 422 See sync_imports for details.
422 423
423 424 """
424 425 return self.sync_imports(True)
425 426
426 427 @contextmanager
427 428 def sync_imports(self, local=True, quiet=False):
428 429 """Context Manager for performing simultaneous local and remote imports.
429 430
430 431 'import x as y' will *not* work. The 'as y' part will simply be ignored.
431 432
432 433 If `local=True`, then the package will also be imported locally.
433 434
434 435 If `quiet=True`, no output will be produced when attempting remote
435 436 imports.
436 437
437 438 Note that remote-only (`local=False`) imports have not been implemented.
438 439
439 440 >>> with view.sync_imports():
440 441 ... from numpy import recarray
441 442 importing recarray from numpy on engine(s)
442 443
443 444 """
444 445 from IPython.utils.py3compat import builtin_mod
445 446 local_import = builtin_mod.__import__
446 447 modules = set()
447 448 results = []
448 449 @util.interactive
449 450 def remote_import(name, fromlist, level):
450 451 """the function to be passed to apply, that actually performs the import
451 452 on the engine, and loads up the user namespace.
452 453 """
453 454 import sys
454 455 user_ns = globals()
455 456 mod = __import__(name, fromlist=fromlist, level=level)
456 457 if fromlist:
457 458 for key in fromlist:
458 459 user_ns[key] = getattr(mod, key)
459 460 else:
460 461 user_ns[name] = sys.modules[name]
461 462
462 463 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
463 464 """the drop-in replacement for __import__, that optionally imports
464 465 locally as well.
465 466 """
466 467 # don't override nested imports
467 468 save_import = builtin_mod.__import__
468 469 builtin_mod.__import__ = local_import
469 470
470 471 if imp.lock_held():
471 472 # this is a side-effect import, don't do it remotely, or even
472 473 # ignore the local effects
473 474 return local_import(name, globals, locals, fromlist, level)
474 475
475 476 imp.acquire_lock()
476 477 if local:
477 478 mod = local_import(name, globals, locals, fromlist, level)
478 479 else:
479 480 raise NotImplementedError("remote-only imports not yet implemented")
480 481 imp.release_lock()
481 482
482 483 key = name+':'+','.join(fromlist or [])
483 484 if level <= 0 and key not in modules:
484 485 modules.add(key)
485 486 if not quiet:
486 487 if fromlist:
487 488 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
488 489 else:
489 490 print("importing %s on engine(s)"%name)
490 491 results.append(self.apply_async(remote_import, name, fromlist, level))
491 492 # restore override
492 493 builtin_mod.__import__ = save_import
493 494
494 495 return mod
495 496
496 497 # override __import__
497 498 builtin_mod.__import__ = view_import
498 499 try:
499 500 # enter the block
500 501 yield
501 502 except ImportError:
502 503 if local:
503 504 raise
504 505 else:
505 506 # ignore import errors if not doing local imports
506 507 pass
507 508 finally:
508 509 # always restore __import__
509 510 builtin_mod.__import__ = local_import
510 511
511 512 for r in results:
512 513 # raise possible remote ImportErrors here
513 514 r.get()
515
516 def use_dill(self):
517 """Expand serialization support with dill
518
519 adds support for closures, etc.
520
521 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
522 """
523 pickleutil.use_dill()
524 return self.apply(pickleutil.use_dill)
514 525
515 526
516 527 @sync_results
517 528 @save_ids
518 529 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
519 530 """calls f(*args, **kwargs) on remote engines, returning the result.
520 531
521 532 This method sets all of `apply`'s flags via this View's attributes.
522 533
523 534 Parameters
524 535 ----------
525 536
526 537 f : callable
527 538
528 539 args : list [default: empty]
529 540
530 541 kwargs : dict [default: empty]
531 542
532 543 targets : target list [default: self.targets]
533 544 where to run
534 545 block : bool [default: self.block]
535 546 whether to block
536 547 track : bool [default: self.track]
537 548 whether to ask zmq to track the message, for safe non-copying sends
538 549
539 550 Returns
540 551 -------
541 552
542 553 if self.block is False:
543 554 returns AsyncResult
544 555 else:
545 556 returns actual result of f(*args, **kwargs) on the engine(s)
546 557 This will be a list of self.targets is also a list (even length 1), or
547 558 the single result if self.targets is an integer engine id
548 559 """
549 560 args = [] if args is None else args
550 561 kwargs = {} if kwargs is None else kwargs
551 562 block = self.block if block is None else block
552 563 track = self.track if track is None else track
553 564 targets = self.targets if targets is None else targets
554 565
555 566 _idents, _targets = self.client._build_targets(targets)
556 567 msg_ids = []
557 568 trackers = []
558 569 for ident in _idents:
559 570 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
560 571 ident=ident)
561 572 if track:
562 573 trackers.append(msg['tracker'])
563 574 msg_ids.append(msg['header']['msg_id'])
564 575 if isinstance(targets, int):
565 576 msg_ids = msg_ids[0]
566 577 tracker = None if track is False else zmq.MessageTracker(*trackers)
567 578 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
568 579 if block:
569 580 try:
570 581 return ar.get()
571 582 except KeyboardInterrupt:
572 583 pass
573 584 return ar
574 585
575 586
576 587 @sync_results
577 588 def map(self, f, *sequences, **kwargs):
578 589 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
579 590
580 591 Parallel version of builtin `map`, using this View's `targets`.
581 592
582 593 There will be one task per target, so work will be chunked
583 594 if the sequences are longer than `targets`.
584 595
585 596 Results can be iterated as they are ready, but will become available in chunks.
586 597
587 598 Parameters
588 599 ----------
589 600
590 601 f : callable
591 602 function to be mapped
592 603 *sequences: one or more sequences of matching length
593 604 the sequences to be distributed and passed to `f`
594 605 block : bool
595 606 whether to wait for the result or not [default self.block]
596 607
597 608 Returns
598 609 -------
599 610
600 611 if block=False:
601 612 AsyncMapResult
602 613 An object like AsyncResult, but which reassembles the sequence of results
603 614 into a single list. AsyncMapResults can be iterated through before all
604 615 results are complete.
605 616 else:
606 617 list
607 618 the result of map(f,*sequences)
608 619 """
609 620
610 621 block = kwargs.pop('block', self.block)
611 622 for k in kwargs.keys():
612 623 if k not in ['block', 'track']:
613 624 raise TypeError("invalid keyword arg, %r"%k)
614 625
615 626 assert len(sequences) > 0, "must have some sequences to map onto!"
616 627 pf = ParallelFunction(self, f, block=block, **kwargs)
617 628 return pf.map(*sequences)
618 629
619 630 @sync_results
620 631 @save_ids
621 632 def execute(self, code, silent=True, targets=None, block=None):
622 633 """Executes `code` on `targets` in blocking or nonblocking manner.
623 634
624 635 ``execute`` is always `bound` (affects engine namespace)
625 636
626 637 Parameters
627 638 ----------
628 639
629 640 code : str
630 641 the code string to be executed
631 642 block : bool
632 643 whether or not to wait until done to return
633 644 default: self.block
634 645 """
635 646 block = self.block if block is None else block
636 647 targets = self.targets if targets is None else targets
637 648
638 649 _idents, _targets = self.client._build_targets(targets)
639 650 msg_ids = []
640 651 trackers = []
641 652 for ident in _idents:
642 653 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
643 654 msg_ids.append(msg['header']['msg_id'])
644 655 if isinstance(targets, int):
645 656 msg_ids = msg_ids[0]
646 657 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
647 658 if block:
648 659 try:
649 660 ar.get()
650 661 except KeyboardInterrupt:
651 662 pass
652 663 return ar
653 664
654 665 def run(self, filename, targets=None, block=None):
655 666 """Execute contents of `filename` on my engine(s).
656 667
657 668 This simply reads the contents of the file and calls `execute`.
658 669
659 670 Parameters
660 671 ----------
661 672
662 673 filename : str
663 674 The path to the file
664 675 targets : int/str/list of ints/strs
665 676 the engines on which to execute
666 677 default : all
667 678 block : bool
668 679 whether or not to wait until done
669 680 default: self.block
670 681
671 682 """
672 683 with open(filename, 'r') as f:
673 684 # add newline in case of trailing indented whitespace
674 685 # which will cause SyntaxError
675 686 code = f.read()+'\n'
676 687 return self.execute(code, block=block, targets=targets)
677 688
678 689 def update(self, ns):
679 690 """update remote namespace with dict `ns`
680 691
681 692 See `push` for details.
682 693 """
683 694 return self.push(ns, block=self.block, track=self.track)
684 695
685 696 def push(self, ns, targets=None, block=None, track=None):
686 697 """update remote namespace with dict `ns`
687 698
688 699 Parameters
689 700 ----------
690 701
691 702 ns : dict
692 703 dict of keys with which to update engine namespace(s)
693 704 block : bool [default : self.block]
694 705 whether to wait to be notified of engine receipt
695 706
696 707 """
697 708
698 709 block = block if block is not None else self.block
699 710 track = track if track is not None else self.track
700 711 targets = targets if targets is not None else self.targets
701 712 # applier = self.apply_sync if block else self.apply_async
702 713 if not isinstance(ns, dict):
703 714 raise TypeError("Must be a dict, not %s"%type(ns))
704 715 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
705 716
706 717 def get(self, key_s):
707 718 """get object(s) by `key_s` from remote namespace
708 719
709 720 see `pull` for details.
710 721 """
711 722 # block = block if block is not None else self.block
712 723 return self.pull(key_s, block=True)
713 724
714 725 def pull(self, names, targets=None, block=None):
715 726 """get object(s) by `name` from remote namespace
716 727
717 728 will return one object if it is a key.
718 729 can also take a list of keys, in which case it will return a list of objects.
719 730 """
720 731 block = block if block is not None else self.block
721 732 targets = targets if targets is not None else self.targets
722 733 applier = self.apply_sync if block else self.apply_async
723 734 if isinstance(names, string_types):
724 735 pass
725 736 elif isinstance(names, (list,tuple,set)):
726 737 for key in names:
727 738 if not isinstance(key, string_types):
728 739 raise TypeError("keys must be str, not type %r"%type(key))
729 740 else:
730 741 raise TypeError("names must be strs, not %r"%names)
731 742 return self._really_apply(util._pull, (names,), block=block, targets=targets)
732 743
733 744 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
734 745 """
735 746 Partition a Python sequence and send the partitions to a set of engines.
736 747 """
737 748 block = block if block is not None else self.block
738 749 track = track if track is not None else self.track
739 750 targets = targets if targets is not None else self.targets
740 751
741 752 # construct integer ID list:
742 753 targets = self.client._build_targets(targets)[1]
743 754
744 755 mapObject = Map.dists[dist]()
745 756 nparts = len(targets)
746 757 msg_ids = []
747 758 trackers = []
748 759 for index, engineid in enumerate(targets):
749 760 partition = mapObject.getPartition(seq, index, nparts)
750 761 if flatten and len(partition) == 1:
751 762 ns = {key: partition[0]}
752 763 else:
753 764 ns = {key: partition}
754 765 r = self.push(ns, block=False, track=track, targets=engineid)
755 766 msg_ids.extend(r.msg_ids)
756 767 if track:
757 768 trackers.append(r._tracker)
758 769
759 770 if track:
760 771 tracker = zmq.MessageTracker(*trackers)
761 772 else:
762 773 tracker = None
763 774
764 775 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
765 776 if block:
766 777 r.wait()
767 778 else:
768 779 return r
769 780
770 781 @sync_results
771 782 @save_ids
772 783 def gather(self, key, dist='b', targets=None, block=None):
773 784 """
774 785 Gather a partitioned sequence on a set of engines as a single local seq.
775 786 """
776 787 block = block if block is not None else self.block
777 788 targets = targets if targets is not None else self.targets
778 789 mapObject = Map.dists[dist]()
779 790 msg_ids = []
780 791
781 792 # construct integer ID list:
782 793 targets = self.client._build_targets(targets)[1]
783 794
784 795 for index, engineid in enumerate(targets):
785 796 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
786 797
787 798 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
788 799
789 800 if block:
790 801 try:
791 802 return r.get()
792 803 except KeyboardInterrupt:
793 804 pass
794 805 return r
795 806
796 807 def __getitem__(self, key):
797 808 return self.get(key)
798 809
799 810 def __setitem__(self,key, value):
800 811 self.update({key:value})
801 812
802 813 def clear(self, targets=None, block=None):
803 814 """Clear the remote namespaces on my engines."""
804 815 block = block if block is not None else self.block
805 816 targets = targets if targets is not None else self.targets
806 817 return self.client.clear(targets=targets, block=block)
807 818
808 819 #----------------------------------------
809 820 # activate for %px, %autopx, etc. magics
810 821 #----------------------------------------
811 822
812 823 def activate(self, suffix=''):
813 824 """Activate IPython magics associated with this View
814 825
815 826 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
816 827
817 828 Parameters
818 829 ----------
819 830
820 831 suffix: str [default: '']
821 832 The suffix, if any, for the magics. This allows you to have
822 833 multiple views associated with parallel magics at the same time.
823 834
824 835 e.g. ``rc[::2].activate(suffix='_even')`` will give you
825 836 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
826 837 on the even engines.
827 838 """
828 839
829 840 from IPython.parallel.client.magics import ParallelMagics
830 841
831 842 try:
832 843 # This is injected into __builtins__.
833 844 ip = get_ipython()
834 845 except NameError:
835 846 print("The IPython parallel magics (%px, etc.) only work within IPython.")
836 847 return
837 848
838 849 M = ParallelMagics(ip, self, suffix)
839 850 ip.magics_manager.register(M)
840 851
841 852
842 853 @skip_doctest
843 854 class LoadBalancedView(View):
844 855 """An load-balancing View that only executes via the Task scheduler.
845 856
846 857 Load-balanced views can be created with the client's `view` method:
847 858
848 859 >>> v = client.load_balanced_view()
849 860
850 861 or targets can be specified, to restrict the potential destinations:
851 862
852 863 >>> v = client.client.load_balanced_view([1,3])
853 864
854 865 which would restrict loadbalancing to between engines 1 and 3.
855 866
856 867 """
857 868
858 869 follow=Any()
859 870 after=Any()
860 871 timeout=CFloat()
861 872 retries = Integer(0)
862 873
863 874 _task_scheme = Any()
864 875 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
865 876
866 877 def __init__(self, client=None, socket=None, **flags):
867 878 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
868 879 self._task_scheme=client._task_scheme
869 880
870 881 def _validate_dependency(self, dep):
871 882 """validate a dependency.
872 883
873 884 For use in `set_flags`.
874 885 """
875 886 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
876 887 return True
877 888 elif isinstance(dep, (list,set, tuple)):
878 889 for d in dep:
879 890 if not isinstance(d, string_types + (AsyncResult,)):
880 891 return False
881 892 elif isinstance(dep, dict):
882 893 if set(dep.keys()) != set(Dependency().as_dict().keys()):
883 894 return False
884 895 if not isinstance(dep['msg_ids'], list):
885 896 return False
886 897 for d in dep['msg_ids']:
887 898 if not isinstance(d, string_types):
888 899 return False
889 900 else:
890 901 return False
891 902
892 903 return True
893 904
894 905 def _render_dependency(self, dep):
895 906 """helper for building jsonable dependencies from various input forms."""
896 907 if isinstance(dep, Dependency):
897 908 return dep.as_dict()
898 909 elif isinstance(dep, AsyncResult):
899 910 return dep.msg_ids
900 911 elif dep is None:
901 912 return []
902 913 else:
903 914 # pass to Dependency constructor
904 915 return list(Dependency(dep))
905 916
906 917 def set_flags(self, **kwargs):
907 918 """set my attribute flags by keyword.
908 919
909 920 A View is a wrapper for the Client's apply method, but with attributes
910 921 that specify keyword arguments, those attributes can be set by keyword
911 922 argument with this method.
912 923
913 924 Parameters
914 925 ----------
915 926
916 927 block : bool
917 928 whether to wait for results
918 929 track : bool
919 930 whether to create a MessageTracker to allow the user to
920 931 safely edit after arrays and buffers during non-copying
921 932 sends.
922 933
923 934 after : Dependency or collection of msg_ids
924 935 Only for load-balanced execution (targets=None)
925 936 Specify a list of msg_ids as a time-based dependency.
926 937 This job will only be run *after* the dependencies
927 938 have been met.
928 939
929 940 follow : Dependency or collection of msg_ids
930 941 Only for load-balanced execution (targets=None)
931 942 Specify a list of msg_ids as a location-based dependency.
932 943 This job will only be run on an engine where this dependency
933 944 is met.
934 945
935 946 timeout : float/int or None
936 947 Only for load-balanced execution (targets=None)
937 948 Specify an amount of time (in seconds) for the scheduler to
938 949 wait for dependencies to be met before failing with a
939 950 DependencyTimeout.
940 951
941 952 retries : int
942 953 Number of times a task will be retried on failure.
943 954 """
944 955
945 956 super(LoadBalancedView, self).set_flags(**kwargs)
946 957 for name in ('follow', 'after'):
947 958 if name in kwargs:
948 959 value = kwargs[name]
949 960 if self._validate_dependency(value):
950 961 setattr(self, name, value)
951 962 else:
952 963 raise ValueError("Invalid dependency: %r"%value)
953 964 if 'timeout' in kwargs:
954 965 t = kwargs['timeout']
955 966 if not isinstance(t, (int, float, type(None))):
956 967 if (not PY3) and (not isinstance(t, long)):
957 968 raise TypeError("Invalid type for timeout: %r"%type(t))
958 969 if t is not None:
959 970 if t < 0:
960 971 raise ValueError("Invalid timeout: %s"%t)
961 972 self.timeout = t
962 973
963 974 @sync_results
964 975 @save_ids
965 976 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
966 977 after=None, follow=None, timeout=None,
967 978 targets=None, retries=None):
968 979 """calls f(*args, **kwargs) on a remote engine, returning the result.
969 980
970 981 This method temporarily sets all of `apply`'s flags for a single call.
971 982
972 983 Parameters
973 984 ----------
974 985
975 986 f : callable
976 987
977 988 args : list [default: empty]
978 989
979 990 kwargs : dict [default: empty]
980 991
981 992 block : bool [default: self.block]
982 993 whether to block
983 994 track : bool [default: self.track]
984 995 whether to ask zmq to track the message, for safe non-copying sends
985 996
986 997 !!!!!! TODO: THE REST HERE !!!!
987 998
988 999 Returns
989 1000 -------
990 1001
991 1002 if self.block is False:
992 1003 returns AsyncResult
993 1004 else:
994 1005 returns actual result of f(*args, **kwargs) on the engine(s)
995 1006 This will be a list of self.targets is also a list (even length 1), or
996 1007 the single result if self.targets is an integer engine id
997 1008 """
998 1009
999 1010 # validate whether we can run
1000 1011 if self._socket.closed:
1001 1012 msg = "Task farming is disabled"
1002 1013 if self._task_scheme == 'pure':
1003 1014 msg += " because the pure ZMQ scheduler cannot handle"
1004 1015 msg += " disappearing engines."
1005 1016 raise RuntimeError(msg)
1006 1017
1007 1018 if self._task_scheme == 'pure':
1008 1019 # pure zmq scheme doesn't support extra features
1009 1020 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1010 1021 "follow, after, retries, targets, timeout"
1011 1022 if (follow or after or retries or targets or timeout):
1012 1023 # hard fail on Scheduler flags
1013 1024 raise RuntimeError(msg)
1014 1025 if isinstance(f, dependent):
1015 1026 # soft warn on functional dependencies
1016 1027 warnings.warn(msg, RuntimeWarning)
1017 1028
1018 1029 # build args
1019 1030 args = [] if args is None else args
1020 1031 kwargs = {} if kwargs is None else kwargs
1021 1032 block = self.block if block is None else block
1022 1033 track = self.track if track is None else track
1023 1034 after = self.after if after is None else after
1024 1035 retries = self.retries if retries is None else retries
1025 1036 follow = self.follow if follow is None else follow
1026 1037 timeout = self.timeout if timeout is None else timeout
1027 1038 targets = self.targets if targets is None else targets
1028 1039
1029 1040 if not isinstance(retries, int):
1030 1041 raise TypeError('retries must be int, not %r'%type(retries))
1031 1042
1032 1043 if targets is None:
1033 1044 idents = []
1034 1045 else:
1035 1046 idents = self.client._build_targets(targets)[0]
1036 1047 # ensure *not* bytes
1037 1048 idents = [ ident.decode() for ident in idents ]
1038 1049
1039 1050 after = self._render_dependency(after)
1040 1051 follow = self._render_dependency(follow)
1041 1052 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1042 1053
1043 1054 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1044 1055 metadata=metadata)
1045 1056 tracker = None if track is False else msg['tracker']
1046 1057
1047 1058 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1048 1059
1049 1060 if block:
1050 1061 try:
1051 1062 return ar.get()
1052 1063 except KeyboardInterrupt:
1053 1064 pass
1054 1065 return ar
1055 1066
1056 1067 @sync_results
1057 1068 @save_ids
1058 1069 def map(self, f, *sequences, **kwargs):
1059 1070 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1060 1071
1061 1072 Parallel version of builtin `map`, load-balanced by this View.
1062 1073
1063 1074 `block`, and `chunksize` can be specified by keyword only.
1064 1075
1065 1076 Each `chunksize` elements will be a separate task, and will be
1066 1077 load-balanced. This lets individual elements be available for iteration
1067 1078 as soon as they arrive.
1068 1079
1069 1080 Parameters
1070 1081 ----------
1071 1082
1072 1083 f : callable
1073 1084 function to be mapped
1074 1085 *sequences: one or more sequences of matching length
1075 1086 the sequences to be distributed and passed to `f`
1076 1087 block : bool [default self.block]
1077 1088 whether to wait for the result or not
1078 1089 track : bool
1079 1090 whether to create a MessageTracker to allow the user to
1080 1091 safely edit after arrays and buffers during non-copying
1081 1092 sends.
1082 1093 chunksize : int [default 1]
1083 1094 how many elements should be in each task.
1084 1095 ordered : bool [default True]
1085 1096 Whether the results should be gathered as they arrive, or enforce
1086 1097 the order of submission.
1087 1098
1088 1099 Only applies when iterating through AsyncMapResult as results arrive.
1089 1100 Has no effect when block=True.
1090 1101
1091 1102 Returns
1092 1103 -------
1093 1104
1094 1105 if block=False:
1095 1106 AsyncMapResult
1096 1107 An object like AsyncResult, but which reassembles the sequence of results
1097 1108 into a single list. AsyncMapResults can be iterated through before all
1098 1109 results are complete.
1099 1110 else:
1100 1111 the result of map(f,*sequences)
1101 1112
1102 1113 """
1103 1114
1104 1115 # default
1105 1116 block = kwargs.get('block', self.block)
1106 1117 chunksize = kwargs.get('chunksize', 1)
1107 1118 ordered = kwargs.get('ordered', True)
1108 1119
1109 1120 keyset = set(kwargs.keys())
1110 1121 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1111 1122 if extra_keys:
1112 1123 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1113 1124
1114 1125 assert len(sequences) > 0, "must have some sequences to map onto!"
1115 1126
1116 1127 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1117 1128 return pf.map(*sequences)
1118 1129
1119 1130 __all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now