##// END OF EJS Templates
add shutdown to Views
MinRK -
Show More
@@ -1,1028 +1,1036 b''
1 1 """Views of remote engines."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import imp
14 14 import sys
15 15 import warnings
16 16 from contextlib import contextmanager
17 17 from types import ModuleType
18 18
19 19 import zmq
20 20
21 21 from IPython.testing import decorators as testdec
22 22 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat
23 23
24 24 from IPython.external.decorator import decorator
25 25
26 26 from . import map as Map
27 27 from . import util
28 28 from .asyncresult import AsyncResult, AsyncMapResult
29 29 from .dependency import Dependency, dependent
30 30 from .remotefunction import ParallelFunction, parallel, remote
31 31
32 32 #-----------------------------------------------------------------------------
33 33 # Decorators
34 34 #-----------------------------------------------------------------------------
35 35
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call.

    Every msg_id that the wrapped call appended to ``self.client.history`` is
    copied into this view's ``history`` and ``outstanding``.  The bookkeeping
    runs in a ``finally`` clause, so it also happens when the call raises.
    """
    n_previous = len(self.client.history)
    try:
        ret = f(self, *args, **kwargs)
    finally:
        # Slice from the old length instead of using a negative count:
        # ``history[-0:]`` is the *entire* list, so a call that submitted
        # nothing would otherwise re-record every message ever sent.
        msg_ids = self.client.history[n_previous:]
        self.history.extend(msg_ids)
        self.outstanding.update(msg_ids)
    return ret
48 48
@decorator
def sync_results(f, self, *args, **kwargs):
    """Copy newly finished results from self.client into our results attribute.

    After the wrapped call returns, every msg_id that this view still lists as
    outstanding but the client no longer does is treated as finished: it is
    dropped from ``self.outstanding`` and its result is copied over.
    """
    result = f(self, *args, **kwargs)
    # msg_ids we are waiting on that the client is no longer waiting on
    finished = self.outstanding.difference(self.client.outstanding)
    self.outstanding = self.outstanding.difference(finished)
    client_results = self.client.results
    for finished_id in finished:
        self.results[finished_id] = client_results[finished_id]
    return result
59 59
@decorator
def spin_after(f, self, *args, **kwargs):
    """Run the wrapped method, then call ``self.spin()`` before returning.

    The spin is skipped if the wrapped method raises.
    """
    outcome = f(self, *args, **kwargs)
    self.spin()
    return outcome
66 66
67 67 #-----------------------------------------------------------------------------
68 68 # Classes
69 69 #-----------------------------------------------------------------------------
70 70
71 71 class View(HasTraits):
72 72 """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.
73 73
74 74 Don't use this class, use subclasses.
75 75
76 76 Methods
77 77 -------
78 78
79 79 spin
80 80 flushes incoming results and registration state changes
81 81 control methods spin, and requesting `ids` also ensures up to date
82 82
83 83 wait
84 84 wait on one or more msg_ids
85 85
86 86 execution methods
87 87 apply
88 88 legacy: execute, run
89 89
90 90 data movement
91 91 push, pull, scatter, gather
92 92
93 93 query methods
94 94 get_result, queue_status, purge_results, result_status
95 95
96 96 control methods
97 97 abort, shutdown
98 98
99 99 """
100 100 # flags
101 101 block=Bool(False)
102 102 track=Bool(True)
103 103 targets = Any()
104 104
105 105 history=List()
106 106 outstanding = Set()
107 107 results = Dict()
108 108 client = Instance('IPython.parallel.client.Client')
109 109
110 110 _socket = Instance('zmq.Socket')
111 111 _flag_names = List(['targets', 'block', 'track'])
112 112 _targets = Any()
113 113 _idents = Any()
114 114
def __init__(self, client=None, socket=None, **flags):
    """Create a view bound to `client` and `socket`.

    Parameters
    ----------

    client : the Client this view submits through; its `block` attribute
        seeds this view's `block` flag
    socket : stored as the private `_socket` trait
    **flags : initial flag values, applied with `set_flags`
    """
    super(View, self).__init__(client=client, _socket=socket)
    # start from the client's blocking behaviour; explicit flags may override it
    self.block = client.block

    self.set_flags(**flags)

    # the base class is abstract: only subclasses may be instantiated
    assert not self.__class__ is View, "Don't use base View objects, use subclasses"
122 122
123 123
def __repr__(self):
    """Return ``<ClassName targets>``, abbreviating long target descriptions."""
    label = str(self.targets)
    if len(label) > 16:
        # keep the first 12 characters and visually close the list
        label = label[:12] + '...]'
    return "<%s %s>" % (self.__class__.__name__, label)
129 129
def set_flags(self, **kwargs):
    """set my attribute flags by keyword.

    Views determine behavior with a few attributes (`block`, `track`, etc.).
    These attributes can be set all at once by name with this method.

    Only names listed in `self._flag_names` are accepted.  All names are
    checked before any attribute is assigned, so an invalid keyword never
    leaves the view with only part of the requested flags applied.

    Parameters
    ----------

    block : bool
        whether to wait for results
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit after arrays and buffers during non-copying
        sends.

    Raises
    ------

    KeyError : if any keyword is not a known flag name
    """
    # validate first, assign second: keeps the update all-or-nothing
    for name in kwargs:
        if name not in self._flag_names:
            raise KeyError("Invalid name: %r"%name)
    for name, value in kwargs.items():
        setattr(self, name, value)
151 151
@contextmanager
def temp_flags(self, **kwargs):
    """temporarily set flags, for use in `with` statements.

    See set_flags for permanent setting of flags

    The current value of every flag in `self._flag_names` is saved on entry
    and restored on exit, whether the block finishes normally or raises.
    The saved values are also restored if applying the temporary flags
    itself fails.

    Examples
    --------

    >>> view.track=False
    ...
    >>> with view.temp_flags(track=True):
    ...    ar = view.apply(dostuff, my_big_array)
    ...    ar.tracker.wait() # wait for send to finish
    >>> view.track
    False

    """
    # preflight: save flags
    saved_flags = {}
    for f in self._flag_names:
        saved_flags[f] = getattr(self, f)
    try:
        # set temporaries inside the try, so that a failure while applying
        # them cannot leave the view half-modified
        self.set_flags(**kwargs)
        # yield to the with-statement block
        yield
    finally:
        # postflight: restore saved flags
        self.set_flags(**saved_flags)
181 181
182 182
183 183 #----------------------------------------------------------------
184 184 # apply
185 185 #----------------------------------------------------------------
186 186
@sync_results
@save_ids
def _really_apply(self, f, args, kwargs, block=None, **options):
    """wrapper for client.send_apply_message

    Abstract in the base class: it always raises NotImplementedError.
    `apply`, `apply_async` and `apply_sync` all funnel into this method.
    """
    raise NotImplementedError("Implement in subclasses")
192 192
def apply(self, f, *args, **kwargs):
    """calls f(*args, **kwargs) on remote engines, returning the result.

    This method sets all apply flags via this View's attributes.

    if self.block is False:
        returns AsyncResult
    else:
        returns actual result of f(*args, **kwargs)
    """
    # no explicit `block` is passed, so _really_apply uses its own default
    return self._really_apply(f, args, kwargs)
204 204
def apply_async(self, f, *args, **kwargs):
    """calls f(*args, **kwargs) on remote engines in a nonblocking manner.

    Forwards to `_really_apply` with block=False, regardless of self.block.

    returns AsyncResult
    """
    return self._really_apply(f, args, kwargs, block=False)
211 211
@spin_after
def apply_sync(self, f, *args, **kwargs):
    """calls f(*args, **kwargs) on remote engines in a blocking manner,
    returning the result.

    Forwards to `_really_apply` with block=True, regardless of self.block;
    the `spin_after` decorator then calls self.spin() once it has returned.

    returns: actual result of f(*args, **kwargs)
    """
    return self._really_apply(f, args, kwargs, block=True)
220 220
221 221 #----------------------------------------------------------------
222 222 # wrappers for client and control methods
223 223 #----------------------------------------------------------------
@sync_results
def spin(self):
    """spin the client, and sync

    Calls self.client.spin(); the `sync_results` decorator then copies any
    newly finished results into this view.
    """
    self.client.spin()
228 228
@sync_results
def wait(self, jobs=None, timeout=-1):
    """waits on one or more `jobs`, for up to `timeout` seconds.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: everything in this view's history (self.history)
    timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

    Returns
    -------

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    if jobs is None:
        # no jobs given: hand the client this view's whole history
        jobs = self.history
    return self.client.wait(jobs, timeout)
253 253
def abort(self, jobs=None, targets=None, block=None):
    """Abort jobs on my engines.

    Parameters
    ----------

    jobs : None, str, list of strs, optional
        if None: abort all jobs.
        else: abort specific msg_id(s).
    targets : optional
        engines to address [default: self.targets]
    block : bool, optional
        passed on to client.abort [default: self.block]
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    return self.client.abort(jobs=jobs, targets=targets, block=block)
267 267
def queue_status(self, targets=None, verbose=False):
    """Fetch the Queue status of my engines

    `targets` defaults to self.targets; `verbose` is passed through to
    client.queue_status unchanged.
    """
    targets = targets if targets is not None else self.targets
    return self.client.queue_status(targets=targets, verbose=verbose)
272 272
def purge_results(self, jobs=[], targets=[]):
    """Instruct the controller to forget specific results.

    `targets` is replaced by self.targets only when it is None or 'all';
    the default empty list is forwarded to the client as-is.
    """
    # NOTE(review): the default lists are shared between calls.  This method
    # never mutates them, but whether client.purge_results does is not visible
    # here -- confirm before relying on the defaults staying empty.
    if targets is None or targets == 'all':
        targets = self.targets
    return self.client.purge_results(jobs=jobs, targets=targets)
278 278
def shutdown(self, targets=None, restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub.

    Parameters
    ----------

    targets : optional
        engines to shut down; None or 'all' means self.targets
    restart : bool
        passed through to client.shutdown unchanged
    hub : bool
        passed through to client.shutdown unchanged
    block : bool, optional
        passed on to client.shutdown [default: self.block]
    """
    block = self.block if block is None else block
    if targets is None or targets == 'all':
        # 'all' is scoped to this view's engines, not the whole cluster
        targets = self.targets
    return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
286
@spin_after
def get_result(self, indices_or_msg_ids=None):
    """return one or more results, specified by history index or msg_id.

    Integers are treated as indices into this view's history and translated
    to msg_ids; anything else is handed to the client untouched.  With no
    argument, the most recent history entry is used.

    See client.get_result for details.

    """
    spec = -1 if indices_or_msg_ids is None else indices_or_msg_ids
    if isinstance(spec, int):
        spec = self.history[spec]
    elif isinstance(spec, (list,tuple,set)):
        # translate any integer entries, leave msg_ids as they are
        spec = [self.history[item] if isinstance(item, int) else item
                for item in list(spec)]
    return self.client.get_result(spec)
297 305
298 306 #-------------------------------------------------------------------
299 307 # Map
300 308 #-------------------------------------------------------------------
301 309
def map(self, f, *sequences, **kwargs):
    """override in subclasses

    The base implementation always raises NotImplementedError; `map_async`,
    `map_sync` and `imap` all delegate to this method.
    """
    raise NotImplementedError
305 313
def map_async(self, f, *sequences, **kwargs):
    """Parallel version of builtin `map`, using this view's engines.

    Same as calling `self.map` with block=False; passing `block` explicitly
    is rejected with a TypeError.

    See `self.map` for details.
    """
    if 'block' in kwargs:
        raise TypeError("map_async doesn't take a `block` keyword argument.")
    nonblocking = dict(kwargs, block=False)
    return self.map(f, *sequences, **nonblocking)
317 325
def map_sync(self, f, *sequences, **kwargs):
    """Parallel version of builtin `map`, using this view's engines.

    Same as calling `self.map` with block=True; passing `block` explicitly
    is rejected with a TypeError.

    See `self.map` for details.
    """
    if 'block' in kwargs:
        raise TypeError("map_sync doesn't take a `block` keyword argument.")
    blocking = dict(kwargs, block=True)
    return self.map(f, *sequences, **blocking)
329 337
def imap(self, f, *sequences, **kwargs):
    """Parallel version of `itertools.imap`.

    Returns an iterator over the result of `self.map_async`.

    See `self.map` for details.

    """

    return iter(self.map_async(f,*sequences, **kwargs))
338 346
339 347 #-------------------------------------------------------------------
340 348 # Decorators
341 349 #-------------------------------------------------------------------
342 350
def remote(self, block=True, **flags):
    """Decorator for making a RemoteFunction

    Wraps the module-level `remote` decorator, binding it to this view.
    """
    # NOTE(review): because `block` defaults to True, the self.block fallback
    # below only applies when a caller passes block=None explicitly.  The
    # sibling `parallel` method defaults to None instead -- confirm which
    # default is intended.
    block = self.block if block is None else block
    return remote(self, block=block, **flags)
347 355
def parallel(self, dist='b', block=None, **flags):
    """Decorator for making a ParallelFunction

    Wraps the module-level `parallel` decorator, binding it to this view.
    `block` defaults to self.block; `dist` and any extra flags are passed
    through unchanged.
    """
    block = self.block if block is None else block
    return parallel(self, dist=dist, block=block, **flags)
352 360
353 361 @testdec.skip_doctest
354 362 class DirectView(View):
355 363 """Direct Multiplexer View of one or more engines.
356 364
357 365 These are created via indexed access to a client:
358 366
359 367 >>> dv_1 = client[1]
360 368 >>> dv_all = client[:]
361 369 >>> dv_even = client[::2]
362 370 >>> dv_some = client[1:3]
363 371
364 372 This object provides dictionary access to engine namespaces:
365 373
366 374 # push a=5:
367 375 >>> dv['a'] = 5
368 376 # pull 'foo':
369 377 >>> dv['foo']
370 378
371 379 """
372 380
def __init__(self, client=None, socket=None, targets=None):
    """Create a DirectView; `targets` is handed to the base class as a flag."""
    super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
375 383
@property
def importer(self):
    """sync_imports(local=True) as a property.

    Each access returns a fresh context manager from `sync_imports`.

    See sync_imports for details.

    In [10]: with v.importer:
       ....:     import numpy
       ....:
    importing numpy on engine(s)

    """
    return self.sync_imports(True)
389 397
@contextmanager
def sync_imports(self, local=True):
    """Context Manager for performing simultaneous local and remote imports.

    'import x as y' will *not* work.  The 'as y' part will simply be ignored.

    >>> with view.sync_imports():
    ...    from numpy import recarray
    importing recarray from numpy on engine(s)

    While the block is active, `__builtin__.__import__` is replaced by a
    wrapper that performs each import locally and also submits it to the
    engines with `apply_async`.  The original `__import__` is always restored
    on exit.  After the block, every remote result is fetched, so a failed
    remote import is raised at that point.

    Parameters
    ----------

    local : bool [default: True]
        whether to perform the imports locally as well.  Remote-only imports
        (local=False) are not implemented and raise NotImplementedError on
        the first import.

    Raises
    ------

    ImportError : a local import failed inside the block (local=True)
    """
    import __builtin__
    local_import = __builtin__.__import__
    modules = set()
    results = []
    @util.interactive
    def remote_import(name, fromlist, level):
        """the function to be passed to apply, that actually performs the import
        on the engine, and loads up the user namespace.
        """
        import sys
        user_ns = globals()
        mod = __import__(name, fromlist=fromlist, level=level)
        if fromlist:
            for key in fromlist:
                user_ns[key] = getattr(mod, key)
        else:
            user_ns[name] = sys.modules[name]

    def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
        """the drop-in replacement for __import__, that optionally imports
        locally as well.
        """
        # don't override nested imports
        save_import = __builtin__.__import__
        __builtin__.__import__ = local_import
        try:
            if imp.lock_held():
                # this is a side-effect import, don't do it remotely, or even
                # ignore the local effects
                return local_import(name, globals, locals, fromlist, level)

            imp.acquire_lock()
            try:
                if local:
                    mod = local_import(name, globals, locals, fromlist, level)
                else:
                    raise NotImplementedError("remote-only imports not yet implemented")
            finally:
                # release even when the import fails, otherwise the import
                # lock stays held by this thread
                imp.release_lock()

            key = name+':'+','.join(fromlist or [])
            if level == -1 and key not in modules:
                modules.add(key)
                if fromlist:
                    print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
                else:
                    print("importing %s on engine(s)"%name)
                results.append(self.apply_async(remote_import, name, fromlist, level))
            return mod
        finally:
            # restore override on every exit path, so later imports in the
            # same block are still synchronized
            __builtin__.__import__ = save_import

    # override __import__
    __builtin__.__import__ = view_import
    try:
        # enter the block
        yield
    except ImportError:
        if local:
            # a failed local import is a real error for the caller
            raise
        # ignore import errors if not doing local imports
    finally:
        # always restore __import__
        __builtin__.__import__ = local_import

    for r in results:
        # raise possible remote ImportErrors here
        r.get()
468 476
469 477
@sync_results
@save_ids
def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
    """calls f(*args, **kwargs) on remote engines, returning the result.

    This method sets all of `apply`'s flags via this View's attributes.
    One apply message is sent per engine identity that the client resolves
    from `targets`.

    Parameters
    ----------

    f : callable

    args : list [default: empty]

    kwargs : dict [default: empty]

    targets : target list [default: self.targets]
        where to run
    block : bool [default: self.block]
        whether to block
    track : bool [default: self.track]
        whether to ask zmq to track the message, for safe non-copying sends

    Returns
    -------

    if block is False:
        returns AsyncResult
    else:
        returns actual result of f(*args, **kwargs) on the engine(s)
        This will be a list if self.targets is also a list (even length 1), or
        the single result if self.targets is an integer engine id

    If a blocking wait is interrupted with KeyboardInterrupt, the AsyncResult
    is returned instead, so the caller can still retrieve the result later.
    """
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs
    block = self.block if block is None else block
    track = self.track if track is None else track
    targets = self.targets if targets is None else targets

    _idents = self.client._build_targets(targets)[0]
    msg_ids = []
    trackers = []
    for ident in _idents:
        msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
                                ident=ident)
        if track:
            trackers.append(msg['tracker'])
        msg_ids.append(msg['msg_id'])
    # use the same truthiness test as the loop above, so that a falsy
    # non-bool `track` never builds a tracker out of an empty list
    tracker = zmq.MessageTracker(*trackers) if track else None
    ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
    if block:
        try:
            return ar.get()
        except KeyboardInterrupt:
            # deliberate: fall through and hand back the AsyncResult
            pass
    return ar
526 534
@spin_after
def map(self, f, *sequences, **kwargs):
    """view.map(f, *sequences, block=self.block) => list|AsyncMapResult

    Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per target, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool
        whether to wait for the result or not [default self.block]

    Any keyword other than `block` and `track` raises TypeError.

    Returns
    -------

    if block=False:
        AsyncMapResult
            An object like AsyncResult, but which reassembles the sequence of results
            into a single list. AsyncMapResults can be iterated through before all
            results are complete.
    else:
        list
            the result of map(f,*sequences)
    """

    block = kwargs.pop('block', self.block)
    accepted = ['block', 'track']
    for keyword in kwargs:
        if keyword not in accepted:
            raise TypeError("invalid keyword arg, %r"%keyword)

    assert sequences, "must have some sequences to map onto!"
    # the actual chunking and submission is delegated to ParallelFunction
    return ParallelFunction(self, f, block=block, **kwargs).map(*sequences)
569 577
def execute(self, code, targets=None, block=None):
    """Executes `code` on `targets` in blocking or nonblocking manner.

    ``execute`` is always `bound` (affects engine namespace)

    Implemented as an apply of `util._execute` with `code` as its only
    argument.

    Parameters
    ----------

    code : str
            the code string to be executed
    targets : optional
            passed to `_really_apply`, which falls back to self.targets
    block : bool
            whether or not to wait until done to return
            default: self.block
    """
    return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
585 593
def run(self, filename, targets=None, block=None):
    """Execute contents of `filename` on my engine(s).

    This simply reads the contents of the file and calls `execute`.
    The file is read locally, by the calling process.

    Parameters
    ----------

    filename : str
            The path to the file
    targets : int/str/list of ints/strs
            the engines on which to execute
            default : all
    block : bool
            whether or not to wait until done
            default: self.block

    """
    with open(filename, 'r') as f:
        # add newline in case of trailing indented whitespace
        # which will cause SyntaxError
        code = f.read()+'\n'
    return self.execute(code, block=block, targets=targets)
609 617
def update(self, ns):
    """update remote namespace with dict `ns`

    Equivalent to `push(ns)` using this view's current `block` and `track`
    flags.

    See `push` for details.
    """
    return self.push(ns, block=self.block, track=self.track)
616 624
def push(self, ns, targets=None, block=None, track=None):
    """update remote namespace with dict `ns`

    Implemented as an apply of `util._push` with `ns` as its only argument.

    Parameters
    ----------

    ns : dict
        dict of keys with which to update engine namespace(s)
    targets : optional [default : self.targets]
        the engines to push to
    block : bool [default : self.block]
        whether to wait to be notified of engine receipt
    track : bool [default : self.track]
        passed through to `_really_apply`

    Raises
    ------

    TypeError : if `ns` is not a dict
    """
    if not isinstance(ns, dict):
        raise TypeError("Must be a dict, not %s"%type(ns))

    # fill in anything the caller left unspecified from this view's flags
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if targets is None:
        targets = self.targets
    return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
637 645
def get(self, key_s):
    """get object(s) by `key_s` from remote namespace

    Always blocking: forwards to `pull` with block=True, ignoring self.block.

    see `pull` for details.
    """
    # block = block if block is not None else self.block
    return self.pull(key_s, block=True)
645 653
def pull(self, names, targets=None, block=True):
    """get object(s) by `name` from remote namespace

    will return one object if it is a key.
    can also take a list of keys, in which case it will return a list of objects.

    Implemented as an apply of `util._pull` with `names` as its only argument.

    Parameters
    ----------

    names : str, or list/tuple/set of strs
        the name(s) to fetch
    targets : optional [default : self.targets]
        the engines to pull from
    block : bool [default : True]
        whether to wait for the result; self.block is only consulted when
        None is passed explicitly

    Raises
    ------

    TypeError : if `names` is neither a string nor a list/tuple/set of strings
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    if isinstance(names, basestring):
        pass
    elif isinstance(names, (list,tuple,set)):
        for key in names:
            if not isinstance(key, basestring):
                raise TypeError("keys must be str, not type %r"%type(key))
    else:
        raise TypeError("names must be strs, not %r"%names)
    return self._really_apply(util._pull, (names,), block=block, targets=targets)
664 672
def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
    """
    Partition a Python sequence and send the partitions to a set of engines.

    `seq` is split into one partition per entry of `targets` by the
    partitioner registered under `dist` in `Map.dists`, and each partition is
    pushed to its engine under the name `key`.  With `flatten`, a partition of
    length one is pushed as its single element rather than as a sequence.

    Returns an AsyncResult covering all pushes when not blocking; when
    blocking, waits for them to finish and returns None.
    """
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if targets is None:
        targets = self.targets

    partitioner = Map.dists[dist]()
    n_engines = len(targets)
    sent_ids = []
    sent_trackers = []
    for position, engine in enumerate(targets):
        chunk = partitioner.getPartition(seq, position, n_engines)
        unwrap = flatten and len(chunk) == 1
        payload = {key: chunk[0]} if unwrap else {key: chunk}
        pushed = self.push(payload, block=False, track=track, targets=engine)
        sent_ids.extend(pushed.msg_ids)
        if track:
            sent_trackers.append(pushed._tracker)

    # combine the per-engine trackers into one, if tracking was requested
    tracker = zmq.MessageTracker(*sent_trackers) if track else None

    combined = AsyncResult(self.client, sent_ids, fname='scatter', targets=targets, tracker=tracker)
    if not block:
        return combined
    combined.wait()
698 706
@sync_results
@save_ids
def gather(self, key, dist='b', targets=None, block=None):
    """
    Gather a partitioned sequence on a set of engines as a single local seq.

    `key` is pulled from each entry of `targets` without blocking, and the
    pieces are reassembled by an AsyncMapResult using the partitioner
    registered under `dist` in `Map.dists`.

    Returns the gathered result when blocking, otherwise the AsyncMapResult.
    If a blocking wait is interrupted with KeyboardInterrupt, the
    AsyncMapResult is returned instead.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets
    partitioner = Map.dists[dist]()
    pulled_ids = []

    for engine in targets:
        piece = self.pull(key, block=False, targets=engine)
        pulled_ids.extend(piece.msg_ids)

    gathered = AsyncMapResult(self.client, pulled_ids, partitioner, fname='gather')

    if block:
        try:
            return gathered.get()
        except KeyboardInterrupt:
            pass
    return gathered
721 729
def __getitem__(self, key):
    """Dictionary-style read: ``view[key]`` is ``view.get(key)``."""
    return self.get(key)
724 732
def __setitem__(self,key, value):
    """Dictionary-style write: ``view[key] = value`` is ``view.update({key: value})``."""
    self.update({key:value})
727 735
def clear(self, targets=None, block=False):
    """Clear the remote namespaces on my engines.

    `targets` defaults to self.targets.  `block` defaults to False;
    self.block is only consulted when None is passed explicitly.
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    return self.client.clear(targets=targets, block=block)
733 741
def kill(self, targets=None, block=True):
    """Kill my engines.

    `targets` defaults to self.targets.  `block` defaults to True;
    self.block is only consulted when None is passed explicitly.
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    return self.client.kill(targets=targets, block=block)
739 747
740 748 #----------------------------------------
741 749 # activate for %px,%autopx magics
742 750 #----------------------------------------
def activate(self):
    """Make this `View` active for parallel magic commands.

    IPython has a magic command syntax to work with `MultiEngineClient` objects.
    In a given IPython session there is a single active one.  While
    there can be many `Views` created and used by the user,
    there is only one active one.  The active `View` is used whenever
    the magic commands %px and %autopx are used.

    The activate() method is called on a given `View` to make it
    active.  Once this has been done, the magic commands can be used.

    Outside IPython, or when the 'parallelmagic' plugin is not loaded, this
    only prints a message and does nothing else.
    """

    try:
        # This is injected into __builtins__.
        ip = get_ipython()
    except NameError:
        # get_ipython does not exist, so we are not running inside IPython
        print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
    else:
        pmagic = ip.plugin_manager.get_plugin('parallelmagic')
        if pmagic is not None:
            # register this view as the one the magics operate on
            pmagic.active_multiengine_client = self
        else:
            print "You must first load the parallelmagic extension " \
                  "by doing '%load_ext parallelmagic'"
768 776
769 777
770 778 @testdec.skip_doctest
771 779 class LoadBalancedView(View):
772 780 """A load-balancing View that only executes via the Task scheduler.
773 781
774 782 Load-balanced views can be created with the client's `view` method:
775 783
776 784 >>> v = client.load_balanced_view()
777 785
778 786 or targets can be specified, to restrict the potential destinations:
779 787
780 788 >>> v = client.load_balanced_view([1,3])
781 789
782 790 which would restrict loadbalancing to between engines 1 and 3.
783 791
784 792 """
785 793
786 794 follow=Any()
787 795 after=Any()
788 796 timeout=CFloat()
789 797
790 798 _task_scheme = Any()
791 799 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout'])
792 800
def __init__(self, client=None, socket=None, **flags):
    """Create a load-balanced view and remember the client's task scheme."""
    super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
    # cached here; _really_apply consults it to decide what the scheduler supports
    self._task_scheme=client._task_scheme
796 804
def _validate_dependency(self, dep):
    """Return True if `dep` is an acceptable dependency specification.

    Accepted forms: None; a single str, AsyncResult or Dependency; a
    list/set/tuple whose items are all str or AsyncResult; or a dict with
    exactly the keys of ``Dependency().as_dict()`` whose 'msg_ids' entry is a
    list of str.  Anything else is rejected.

    For use in `set_flags`.
    """
    if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
        return True

    if isinstance(dep, (list,set, tuple)):
        return all(isinstance(item, (str, AsyncResult)) for item in dep)

    if isinstance(dep, dict):
        expected_keys = set(Dependency().as_dict().keys())
        if set(dep.keys()) != expected_keys:
            return False
        ids = dep['msg_ids']
        if not isinstance(ids, list):
            return False
        return all(isinstance(msg_id, str) for msg_id in ids)

    return False
820 828
def _render_dependency(self, dep):
    """Convert any accepted dependency form into its jsonable representation.

    None becomes an empty list, a Dependency its `as_dict()`, an AsyncResult
    its `msg_ids`; everything else is passed through the Dependency
    constructor and returned as a list.
    """
    if dep is None:
        return []
    if isinstance(dep, Dependency):
        return dep.as_dict()
    if isinstance(dep, AsyncResult):
        return dep.msg_ids
    # pass to Dependency constructor
    return list(Dependency(dep))
832 840
def set_flags(self, **kwargs):
    """set my attribute flags by keyword.

    A View is a wrapper for the Client's apply method, but with attributes
    that specify keyword arguments, those attributes can be set by keyword
    argument with this method.

    All values are validated *before* anything is assigned, so a rejected
    call leaves the view's flags untouched.

    Parameters
    ----------

    block : bool
        whether to wait for results
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit after arrays and buffers during non-copying
        sends.

    after : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a time-based dependency.
        This job will only be run *after* the dependencies
        have been met.

    follow : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a location-based dependency.
        This job will only be run on an engine where this dependency
        is met.

    timeout : float/int or None
        Only for load-balanced execution (targets=None)
        Specify an amount of time (in seconds) for the scheduler to
        wait for dependencies to be met before failing with a
        DependencyTimeout.

    Raises
    ------

    KeyError : a keyword is not a known flag name
    ValueError : `follow`/`after` is not a valid dependency, or `timeout`
        is negative
    TypeError : `timeout` is not a number or None
    """
    # unknown names are reported first, as the base class would
    for name in kwargs:
        if name not in self._flag_names:
            raise KeyError("Invalid name: %r"%name)

    for name in ('follow', 'after'):
        if name in kwargs:
            value = kwargs[name]
            if not self._validate_dependency(value):
                raise ValueError("Invalid dependency: %r"%value)

    if 'timeout' in kwargs:
        t = kwargs['timeout']
        if not isinstance(t, (int, long, float, type(None))):
            raise TypeError("Invalid type for timeout: %r"%type(t))
        if t is not None:
            if t < 0:
                raise ValueError("Invalid timeout: %s"%t)

    # everything checked out: let the base class do the actual assignment
    super(LoadBalancedView, self).set_flags(**kwargs)
885 893
@sync_results
@save_ids
def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
                  after=None, follow=None, timeout=None,
                  targets=None):
    """Submit f(*args, **kwargs) as one load-balanced task.

    Every flag left as None falls back to the attribute of the same name
    on this view, so values passed here only affect this single call.

    Parameters
    ----------

    f : callable

    args : list [default: empty]

    kwargs : dict [default: empty]

    block : bool [default: self.block]
        whether to wait for the result before returning
    track : bool [default: self.track]
        whether to ask zmq to track the message, for safe non-copying sends
    after : Dependency or collection of msg_ids [default: self.after]
        time-based dependency; rendered with `_render_dependency` and sent
        in the message subheader
    follow : Dependency or collection of msg_ids [default: self.follow]
        location-based dependency; rendered with `_render_dependency` and
        sent in the message subheader
    timeout : float/int or None [default: self.timeout]
        seconds the scheduler should wait for dependencies; sent in the
        message subheader
    targets : [default: self.targets]
        resolved to engine identities with `client._build_targets`; when
        still None after the fallback, an empty list is sent

    Returns
    -------

    if block is true:
        the value of `AsyncResult.get()`; if that wait is interrupted by
        KeyboardInterrupt, the AsyncResult itself is returned instead
    else:
        the AsyncResult for the submitted task

    Raises
    ------

    RuntimeError
        if the task socket is closed, or if the scheduler scheme is 'pure'
        and `follow` or `after` was passed explicitly to this call
    """

    # refuse to submit if the task socket is gone
    if self._socket.closed:
        reason = "Task farming is disabled"
        if self._task_scheme == 'pure':
            reason = (reason
                      + " because the pure ZMQ scheduler cannot handle"
                      + " disappearing engines.")
        raise RuntimeError(reason)

    if self._task_scheme == 'pure':
        unsupported = "Pure ZMQ scheduler doesn't support dependencies"
        # explicit DAG dependencies are a hard error; note that this looks
        # at the arguments as passed, before the view defaults are applied
        if follow or after:
            raise RuntimeError(unsupported)
        # functional dependencies only produce a warning
        if isinstance(f, dependent):
            warnings.warn(unsupported, RuntimeWarning)

    # fill in every unspecified flag from the view
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if after is None:
        after = self.after
    if follow is None:
        follow = self.follow
    if timeout is None:
        timeout = self.timeout
    if targets is None:
        targets = self.targets

    idents = [] if targets is None else self.client._build_targets(targets)[0]

    subheader = {
        'after': self._render_dependency(after),
        'follow': self._render_dependency(follow),
        'timeout': timeout,
        'targets': idents,
    }

    sent = self.client.send_apply_message(self._socket, f, args, kwargs,
                                          track=track, subheader=subheader)
    if track is False:
        tracker = None
    else:
        tracker = sent['tracker']

    ar = AsyncResult(self.client, sent['msg_id'], fname=f.__name__,
                     targets=None, tracker=tracker)

    if not block:
        return ar
    try:
        return ar.get()
    except KeyboardInterrupt:
        # interrupting the wait hands back the AsyncResult rather than
        # propagating the interrupt
        return ar
971 979
@spin_after
@save_ids
def map(self, f, *sequences, **kwargs):
    """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult

    Parallel version of builtin `map`, load-balanced by this View.

    `block`, `track`, and `chunksize` can be specified by keyword only.

    Each `chunksize` elements will be a separate task, and will be
    load-balanced. This lets individual elements be available for iteration
    as soon as they arrive.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool
        whether to wait for the result or not [default self.block]
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit arrays and buffers after non-copying sends.
        Accepted as a keyword, but this method does not forward it to
        the ParallelFunction it builds.
    chunksize : int
        how many elements should be in each task [default 1]

    Returns
    -------

    if block=False:
        AsyncMapResult
            An object like AsyncResult, but which reassembles the sequence of results
            into a single list. AsyncMapResults can be iterated through before all
            results are complete.
    else:
        the result of map(f,*sequences)

    Raises
    ------

    TypeError
        if a keyword other than `block`, `track` or `chunksize` is given
    AssertionError
        if no sequences are given
    """

    # default
    block = kwargs.get('block', self.block)
    chunksize = kwargs.get('chunksize', 1)

    # set.difference returns a new set.  The in-place difference_update
    # used here before returns None, so unknown keywords were never rejected.
    extra_keys = set(kwargs).difference(('block', 'track', 'chunksize'))
    if extra_keys:
        # sorted so the message is deterministic
        raise TypeError("Invalid kwargs: %s"%sorted(extra_keys))

    assert len(sequences) > 0, "must have some sequences to map onto!"

    pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
    return pf.map(*sequences)
1027 1035
# Names exported by `from <this module> import *`.
__all__ = ['LoadBalancedView', 'DirectView']
General Comments 0
You need to be logged in to leave comments. Login now