fix scatter/gather with targets='all'...
MinRK
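This commit makes `DirectView.scatter` and `DirectView.gather` resolve lazy target specifications (such as `targets='all'`) into concrete integer engine IDs via `client._build_targets()` before partitioning, and adds a regression test (`test_scatter_gather_lazy`). A minimal usage sketch of the behavior being fixed, assuming a running IPython cluster with registered engines:

    from IPython.parallel import Client

    client = Client()                            # connect to the default running cluster
    view = client.direct_view(targets='all')     # lazy target spec, resolved per call
    view.scatter('x', range(64))                 # partition the sequence across all engines
    gathered = view.gather('x', block=True)      # reassemble the partitions locally
    assert gathered == range(64)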
@@ -1,1069 +1,1075 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 29 )
30 30 from IPython.external.decorator import decorator
31 31
32 32 from IPython.parallel import util
33 33 from IPython.parallel.controller.dependency import Dependency, dependent
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
41 41 #-----------------------------------------------------------------------------
42 42
43 43 @decorator
44 44 def save_ids(f, self, *args, **kwargs):
45 45 """Keep our history and outstanding attributes up to date after a method call."""
46 46 n_previous = len(self.client.history)
47 47 try:
48 48 ret = f(self, *args, **kwargs)
49 49 finally:
50 50 nmsgs = len(self.client.history) - n_previous
51 51 msg_ids = self.client.history[-nmsgs:]
52 52 self.history.extend(msg_ids)
53 53 map(self.outstanding.add, msg_ids)
54 54 return ret
55 55
56 56 @decorator
57 57 def sync_results(f, self, *args, **kwargs):
58 58 """sync relevant results from self.client to our results attribute."""
59 59 ret = f(self, *args, **kwargs)
60 60 delta = self.outstanding.difference(self.client.outstanding)
61 61 completed = self.outstanding.intersection(delta)
62 62 self.outstanding = self.outstanding.difference(completed)
63 63 for msg_id in completed:
64 64 self.results[msg_id] = self.client.results[msg_id]
65 65 return ret
66 66
67 67 @decorator
68 68 def spin_after(f, self, *args, **kwargs):
69 69 """call spin after the method."""
70 70 ret = f(self, *args, **kwargs)
71 71 self.spin()
72 72 return ret
73 73
74 74 #-----------------------------------------------------------------------------
75 75 # Classes
76 76 #-----------------------------------------------------------------------------
77 77
78 78 @skip_doctest
79 79 class View(HasTraits):
80 80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81 81
82 82 Don't use this class, use subclasses.
83 83
84 84 Methods
85 85 -------
86 86
87 87 spin
88 88 flushes incoming results and registration state changes
89 89 control methods spin, and requesting `ids` also ensures the state is up to date
90 90
91 91 wait
92 92 wait on one or more msg_ids
93 93
94 94 execution methods
95 95 apply
96 96 legacy: execute, run
97 97
98 98 data movement
99 99 push, pull, scatter, gather
100 100
101 101 query methods
102 102 get_result, queue_status, purge_results, result_status
103 103
104 104 control methods
105 105 abort, shutdown
106 106
107 107 """
108 108 # flags
109 109 block=Bool(False)
110 110 track=Bool(True)
111 111 targets = Any()
112 112
113 113 history=List()
114 114 outstanding = Set()
115 115 results = Dict()
116 116 client = Instance('IPython.parallel.Client')
117 117
118 118 _socket = Instance('zmq.Socket')
119 119 _flag_names = List(['targets', 'block', 'track'])
120 120 _targets = Any()
121 121 _idents = Any()
122 122
123 123 def __init__(self, client=None, socket=None, **flags):
124 124 super(View, self).__init__(client=client, _socket=socket)
125 125 self.block = client.block
126 126
127 127 self.set_flags(**flags)
128 128
129 129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130 130
131 131
132 132 def __repr__(self):
133 133 strtargets = str(self.targets)
134 134 if len(strtargets) > 16:
135 135 strtargets = strtargets[:12]+'...]'
136 136 return "<%s %s>"%(self.__class__.__name__, strtargets)
137 137
138 138 def set_flags(self, **kwargs):
139 139 """set my attribute flags by keyword.
140 140
141 141 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 142 These attributes can be set all at once by name with this method.
143 143
144 144 Parameters
145 145 ----------
146 146
147 147 block : bool
148 148 whether to wait for results
149 149 track : bool
150 150 whether to create a MessageTracker to allow the user to
151 151 safely edit arrays and buffers after non-copying
152 152 sends.
153 153 """
154 154 for name, value in kwargs.iteritems():
155 155 if name not in self._flag_names:
156 156 raise KeyError("Invalid name: %r"%name)
157 157 else:
158 158 setattr(self, name, value)
159 159
160 160 @contextmanager
161 161 def temp_flags(self, **kwargs):
162 162 """temporarily set flags, for use in `with` statements.
163 163
164 164 See set_flags for permanent setting of flags
165 165
166 166 Examples
167 167 --------
168 168
169 169 >>> view.track=False
170 170 ...
171 171 >>> with view.temp_flags(track=True):
172 172 ... ar = view.apply(dostuff, my_big_array)
173 173 ... ar.tracker.wait() # wait for send to finish
174 174 >>> view.track
175 175 False
176 176
177 177 """
178 178 # preflight: save flags, and set temporaries
179 179 saved_flags = {}
180 180 for f in self._flag_names:
181 181 saved_flags[f] = getattr(self, f)
182 182 self.set_flags(**kwargs)
183 183 # yield to the with-statement block
184 184 try:
185 185 yield
186 186 finally:
187 187 # postflight: restore saved flags
188 188 self.set_flags(**saved_flags)
189 189
190 190
191 191 #----------------------------------------------------------------
192 192 # apply
193 193 #----------------------------------------------------------------
194 194
195 195 @sync_results
196 196 @save_ids
197 197 def _really_apply(self, f, args, kwargs, block=None, **options):
198 198 """wrapper for client.send_apply_message"""
199 199 raise NotImplementedError("Implement in subclasses")
200 200
201 201 def apply(self, f, *args, **kwargs):
202 202 """calls f(*args, **kwargs) on remote engines, returning the result.
203 203
204 204 This method sets all apply flags via this View's attributes.
205 205
206 206 if self.block is False:
207 207 returns AsyncResult
208 208 else:
209 209 returns actual result of f(*args, **kwargs)
210 210 """
211 211 return self._really_apply(f, args, kwargs)
212 212
213 213 def apply_async(self, f, *args, **kwargs):
214 214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215 215
216 216 returns AsyncResult
217 217 """
218 218 return self._really_apply(f, args, kwargs, block=False)
219 219
220 220 @spin_after
221 221 def apply_sync(self, f, *args, **kwargs):
222 222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 223 returning the result.
224 224
225 225 returns: actual result of f(*args, **kwargs)
226 226 """
227 227 return self._really_apply(f, args, kwargs, block=True)
228 228
229 229 #----------------------------------------------------------------
230 230 # wrappers for client and control methods
231 231 #----------------------------------------------------------------
232 232 @sync_results
233 233 def spin(self):
234 234 """spin the client, and sync"""
235 235 self.client.spin()
236 236
237 237 @sync_results
238 238 def wait(self, jobs=None, timeout=-1):
239 239 """waits on one or more `jobs`, for up to `timeout` seconds.
240 240
241 241 Parameters
242 242 ----------
243 243
244 244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 245 ints are indices to self.history
246 246 strs are msg_ids
247 247 default: wait on all outstanding messages
248 248 timeout : float
249 249 a time in seconds, after which to give up.
250 250 default is -1, which means no timeout
251 251
252 252 Returns
253 253 -------
254 254
255 255 True : when all msg_ids are done
256 256 False : timeout reached, some msg_ids still outstanding
257 257 """
258 258 if jobs is None:
259 259 jobs = self.history
260 260 return self.client.wait(jobs, timeout)
261 261
262 262 def abort(self, jobs=None, targets=None, block=None):
263 263 """Abort jobs on my engines.
264 264
265 265 Parameters
266 266 ----------
267 267
268 268 jobs : None, str, list of strs, optional
269 269 if None: abort all jobs.
270 270 else: abort specific msg_id(s).
271 271 """
272 272 block = block if block is not None else self.block
273 273 targets = targets if targets is not None else self.targets
274 274 jobs = jobs if jobs is not None else list(self.outstanding)
275 275
276 276 return self.client.abort(jobs=jobs, targets=targets, block=block)
277 277
278 278 def queue_status(self, targets=None, verbose=False):
279 279 """Fetch the Queue status of my engines"""
280 280 targets = targets if targets is not None else self.targets
281 281 return self.client.queue_status(targets=targets, verbose=verbose)
282 282
283 283 def purge_results(self, jobs=[], targets=[]):
284 284 """Instruct the controller to forget specific results."""
285 285 if targets is None or targets == 'all':
286 286 targets = self.targets
287 287 return self.client.purge_results(jobs=jobs, targets=targets)
288 288
289 289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 290 """Terminates one or more engine processes, optionally including the hub.
291 291 """
292 292 block = self.block if block is None else block
293 293 if targets is None or targets == 'all':
294 294 targets = self.targets
295 295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296 296
297 297 @spin_after
298 298 def get_result(self, indices_or_msg_ids=None):
299 299 """return one or more results, specified by history index or msg_id.
300 300
301 301 See client.get_result for details.
302 302
303 303 """
304 304
305 305 if indices_or_msg_ids is None:
306 306 indices_or_msg_ids = -1
307 307 if isinstance(indices_or_msg_ids, int):
308 308 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 310 indices_or_msg_ids = list(indices_or_msg_ids)
311 311 for i,index in enumerate(indices_or_msg_ids):
312 312 if isinstance(index, int):
313 313 indices_or_msg_ids[i] = self.history[index]
314 314 return self.client.get_result(indices_or_msg_ids)
315 315
316 316 #-------------------------------------------------------------------
317 317 # Map
318 318 #-------------------------------------------------------------------
319 319
320 320 def map(self, f, *sequences, **kwargs):
321 321 """override in subclasses"""
322 322 raise NotImplementedError
323 323
324 324 def map_async(self, f, *sequences, **kwargs):
325 325 """Parallel version of builtin `map`, using this view's engines.
326 326
327 327 This is equivalent to map(..., block=False)
328 328
329 329 See `self.map` for details.
330 330 """
331 331 if 'block' in kwargs:
332 332 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 333 kwargs['block'] = False
334 334 return self.map(f,*sequences,**kwargs)
335 335
336 336 def map_sync(self, f, *sequences, **kwargs):
337 337 """Parallel version of builtin `map`, using this view's engines.
338 338
339 339 This is equivalent to map(..., block=True)
340 340
341 341 See `self.map` for details.
342 342 """
343 343 if 'block' in kwargs:
344 344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 345 kwargs['block'] = True
346 346 return self.map(f,*sequences,**kwargs)
347 347
348 348 def imap(self, f, *sequences, **kwargs):
349 349 """Parallel version of `itertools.imap`.
350 350
351 351 See `self.map` for details.
352 352
353 353 """
354 354
355 355 return iter(self.map_async(f,*sequences, **kwargs))
356 356
357 357 #-------------------------------------------------------------------
358 358 # Decorators
359 359 #-------------------------------------------------------------------
360 360
361 361 def remote(self, block=True, **flags):
362 362 """Decorator for making a RemoteFunction"""
363 363 block = self.block if block is None else block
364 364 return remote(self, block=block, **flags)
365 365
366 366 def parallel(self, dist='b', block=None, **flags):
367 367 """Decorator for making a ParallelFunction"""
368 368 block = self.block if block is None else block
369 369 return parallel(self, dist=dist, block=block, **flags)
370 370
371 371 @skip_doctest
372 372 class DirectView(View):
373 373 """Direct Multiplexer View of one or more engines.
374 374
375 375 These are created via indexed access to a client:
376 376
377 377 >>> dv_1 = client[1]
378 378 >>> dv_all = client[:]
379 379 >>> dv_even = client[::2]
380 380 >>> dv_some = client[1:3]
381 381
382 382 This object provides dictionary access to engine namespaces:
383 383
384 384 # push a=5:
385 385 >>> dv['a'] = 5
386 386 # pull 'foo':
387 387 >>> dv['foo']
388 388
389 389 """
390 390
391 391 def __init__(self, client=None, socket=None, targets=None):
392 392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393 393
394 394 @property
395 395 def importer(self):
396 396 """sync_imports(local=True) as a property.
397 397
398 398 See sync_imports for details.
399 399
400 400 """
401 401 return self.sync_imports(True)
402 402
403 403 @contextmanager
404 404 def sync_imports(self, local=True, quiet=False):
405 405 """Context Manager for performing simultaneous local and remote imports.
406 406
407 407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408 408
409 409 If `local=True`, then the package will also be imported locally.
410 410
411 411 If `quiet=True`, no output will be produced when attempting remote
412 412 imports.
413 413
414 414 Note that remote-only (`local=False`) imports have not been implemented.
415 415
416 416 >>> with view.sync_imports():
417 417 ... from numpy import recarray
418 418 importing recarray from numpy on engine(s)
419 419
420 420 """
421 421 import __builtin__
422 422 local_import = __builtin__.__import__
423 423 modules = set()
424 424 results = []
425 425 @util.interactive
426 426 def remote_import(name, fromlist, level):
427 427 """the function to be passed to apply, that actually performs the import
428 428 on the engine, and loads up the user namespace.
429 429 """
430 430 import sys
431 431 user_ns = globals()
432 432 mod = __import__(name, fromlist=fromlist, level=level)
433 433 if fromlist:
434 434 for key in fromlist:
435 435 user_ns[key] = getattr(mod, key)
436 436 else:
437 437 user_ns[name] = sys.modules[name]
438 438
439 439 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
440 440 """the drop-in replacement for __import__, that optionally imports
441 441 locally as well.
442 442 """
443 443 # don't override nested imports
444 444 save_import = __builtin__.__import__
445 445 __builtin__.__import__ = local_import
446 446
447 447 if imp.lock_held():
448 448 # this is a side-effect import, don't do it remotely, or even
449 449 # ignore the local effects
450 450 return local_import(name, globals, locals, fromlist, level)
451 451
452 452 imp.acquire_lock()
453 453 if local:
454 454 mod = local_import(name, globals, locals, fromlist, level)
455 455 else:
456 456 raise NotImplementedError("remote-only imports not yet implemented")
457 457 imp.release_lock()
458 458
459 459 key = name+':'+','.join(fromlist or [])
460 460 if level == -1 and key not in modules:
461 461 modules.add(key)
462 462 if not quiet:
463 463 if fromlist:
464 464 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
465 465 else:
466 466 print "importing %s on engine(s)"%name
467 467 results.append(self.apply_async(remote_import, name, fromlist, level))
468 468 # restore override
469 469 __builtin__.__import__ = save_import
470 470
471 471 return mod
472 472
473 473 # override __import__
474 474 __builtin__.__import__ = view_import
475 475 try:
476 476 # enter the block
477 477 yield
478 478 except ImportError:
479 479 if local:
480 480 raise
481 481 else:
482 482 # ignore import errors if not doing local imports
483 483 pass
484 484 finally:
485 485 # always restore __import__
486 486 __builtin__.__import__ = local_import
487 487
488 488 for r in results:
489 489 # raise possible remote ImportErrors here
490 490 r.get()
491 491
492 492
493 493 @sync_results
494 494 @save_ids
495 495 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
496 496 """calls f(*args, **kwargs) on remote engines, returning the result.
497 497
498 498 This method sets all of `apply`'s flags via this View's attributes.
499 499
500 500 Parameters
501 501 ----------
502 502
503 503 f : callable
504 504
505 505 args : list [default: empty]
506 506
507 507 kwargs : dict [default: empty]
508 508
509 509 targets : target list [default: self.targets]
510 510 where to run
511 511 block : bool [default: self.block]
512 512 whether to block
513 513 track : bool [default: self.track]
514 514 whether to ask zmq to track the message, for safe non-copying sends
515 515
516 516 Returns
517 517 -------
518 518
519 519 if self.block is False:
520 520 returns AsyncResult
521 521 else:
522 522 returns actual result of f(*args, **kwargs) on the engine(s)
523 523 This will be a list if self.targets is also a list (even length 1), or
524 524 the single result if self.targets is an integer engine id
525 525 """
526 526 args = [] if args is None else args
527 527 kwargs = {} if kwargs is None else kwargs
528 528 block = self.block if block is None else block
529 529 track = self.track if track is None else track
530 530 targets = self.targets if targets is None else targets
531 531
532 532 _idents = self.client._build_targets(targets)[0]
533 533 msg_ids = []
534 534 trackers = []
535 535 for ident in _idents:
536 536 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
537 537 ident=ident)
538 538 if track:
539 539 trackers.append(msg['tracker'])
540 540 msg_ids.append(msg['header']['msg_id'])
541 541 tracker = None if track is False else zmq.MessageTracker(*trackers)
542 542 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
543 543 if block:
544 544 try:
545 545 return ar.get()
546 546 except KeyboardInterrupt:
547 547 pass
548 548 return ar
549 549
550 550 @spin_after
551 551 def map(self, f, *sequences, **kwargs):
552 552 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
553 553
554 554 Parallel version of builtin `map`, using this View's `targets`.
555 555
556 556 There will be one task per target, so work will be chunked
557 557 if the sequences are longer than the number of `targets`.
558 558
559 559 Results can be iterated as they are ready, but will become available in chunks.
560 560
561 561 Parameters
562 562 ----------
563 563
564 564 f : callable
565 565 function to be mapped
566 566 *sequences: one or more sequences of matching length
567 567 the sequences to be distributed and passed to `f`
568 568 block : bool
569 569 whether to wait for the result or not [default self.block]
570 570
571 571 Returns
572 572 -------
573 573
574 574 if block=False:
575 575 AsyncMapResult
576 576 An object like AsyncResult, but which reassembles the sequence of results
577 577 into a single list. AsyncMapResults can be iterated through before all
578 578 results are complete.
579 579 else:
580 580 list
581 581 the result of map(f,*sequences)
582 582 """
583 583
584 584 block = kwargs.pop('block', self.block)
585 585 for k in kwargs.keys():
586 586 if k not in ['block', 'track']:
587 587 raise TypeError("invalid keyword arg, %r"%k)
588 588
589 589 assert len(sequences) > 0, "must have some sequences to map onto!"
590 590 pf = ParallelFunction(self, f, block=block, **kwargs)
591 591 return pf.map(*sequences)
592 592
593 593 def execute(self, code, targets=None, block=None):
594 594 """Executes `code` on `targets` in blocking or nonblocking manner.
595 595
596 596 ``execute`` is always `bound` (affects engine namespace)
597 597
598 598 Parameters
599 599 ----------
600 600
601 601 code : str
602 602 the code string to be executed
603 603 block : bool
604 604 whether or not to wait until done to return
605 605 default: self.block
606 606 """
607 607 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
608 608
609 609 def run(self, filename, targets=None, block=None):
610 610 """Execute contents of `filename` on my engine(s).
611 611
612 612 This simply reads the contents of the file and calls `execute`.
613 613
614 614 Parameters
615 615 ----------
616 616
617 617 filename : str
618 618 The path to the file
619 619 targets : int/str/list of ints/strs
620 620 the engines on which to execute
621 621 default : all
622 622 block : bool
623 623 whether or not to wait until done
624 624 default: self.block
625 625
626 626 """
627 627 with open(filename, 'r') as f:
628 628 # add newline in case of trailing indented whitespace
629 629 # which will cause SyntaxError
630 630 code = f.read()+'\n'
631 631 return self.execute(code, block=block, targets=targets)
632 632
633 633 def update(self, ns):
634 634 """update remote namespace with dict `ns`
635 635
636 636 See `push` for details.
637 637 """
638 638 return self.push(ns, block=self.block, track=self.track)
639 639
640 640 def push(self, ns, targets=None, block=None, track=None):
641 641 """update remote namespace with dict `ns`
642 642
643 643 Parameters
644 644 ----------
645 645
646 646 ns : dict
647 647 dict of keys with which to update engine namespace(s)
648 648 block : bool [default : self.block]
649 649 whether to wait to be notified of engine receipt
650 650
651 651 """
652 652
653 653 block = block if block is not None else self.block
654 654 track = track if track is not None else self.track
655 655 targets = targets if targets is not None else self.targets
656 656 # applier = self.apply_sync if block else self.apply_async
657 657 if not isinstance(ns, dict):
658 658 raise TypeError("Must be a dict, not %s"%type(ns))
659 659 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
660 660
661 661 def get(self, key_s):
662 662 """get object(s) by `key_s` from remote namespace
663 663
664 664 see `pull` for details.
665 665 """
666 666 # block = block if block is not None else self.block
667 667 return self.pull(key_s, block=True)
668 668
669 669 def pull(self, names, targets=None, block=None):
670 670 """get object(s) by `name` from remote namespace
671 671
672 672 will return one object if `names` is a single key.
673 673 can also take a list of keys, in which case it will return a list of objects.
674 674 """
675 675 block = block if block is not None else self.block
676 676 targets = targets if targets is not None else self.targets
677 677 applier = self.apply_sync if block else self.apply_async
678 678 if isinstance(names, basestring):
679 679 pass
680 680 elif isinstance(names, (list,tuple,set)):
681 681 for key in names:
682 682 if not isinstance(key, basestring):
683 683 raise TypeError("keys must be str, not type %r"%type(key))
684 684 else:
685 685 raise TypeError("names must be strs, not %r"%names)
686 686 return self._really_apply(util._pull, (names,), block=block, targets=targets)
687 687
688 688 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
689 689 """
690 690 Partition a Python sequence and send the partitions to a set of engines.
691 691 """
692 692 block = block if block is not None else self.block
693 693 track = track if track is not None else self.track
694 694 targets = targets if targets is not None else self.targets
695
696 # construct integer ID list:
697 targets = self.client._build_targets(targets)[1]
695 698
696 699 mapObject = Map.dists[dist]()
697 700 nparts = len(targets)
698 701 msg_ids = []
699 702 trackers = []
700 703 for index, engineid in enumerate(targets):
701 704 partition = mapObject.getPartition(seq, index, nparts)
702 705 if flatten and len(partition) == 1:
703 706 ns = {key: partition[0]}
704 707 else:
705 708 ns = {key: partition}
706 709 r = self.push(ns, block=False, track=track, targets=engineid)
707 710 msg_ids.extend(r.msg_ids)
708 711 if track:
709 712 trackers.append(r._tracker)
710 713
711 714 if track:
712 715 tracker = zmq.MessageTracker(*trackers)
713 716 else:
714 717 tracker = None
715 718
716 719 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
717 720 if block:
718 721 r.wait()
719 722 else:
720 723 return r
721 724
722 725 @sync_results
723 726 @save_ids
724 727 def gather(self, key, dist='b', targets=None, block=None):
725 728 """
726 729 Gather a partitioned sequence on a set of engines as a single local seq.
727 730 """
728 731 block = block if block is not None else self.block
729 732 targets = targets if targets is not None else self.targets
730 733 mapObject = Map.dists[dist]()
731 734 msg_ids = []
732 735
736 # construct integer ID list:
737 targets = self.client._build_targets(targets)[1]
738
733 739 for index, engineid in enumerate(targets):
734 740 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
735 741
736 742 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
737 743
738 744 if block:
739 745 try:
740 746 return r.get()
741 747 except KeyboardInterrupt:
742 748 pass
743 749 return r
744 750
745 751 def __getitem__(self, key):
746 752 return self.get(key)
747 753
748 754 def __setitem__(self,key, value):
749 755 self.update({key:value})
750 756
751 757 def clear(self, targets=None, block=False):
752 758 """Clear the remote namespaces on my engines."""
753 759 block = block if block is not None else self.block
754 760 targets = targets if targets is not None else self.targets
755 761 return self.client.clear(targets=targets, block=block)
756 762
757 763 def kill(self, targets=None, block=True):
758 764 """Kill my engines."""
759 765 block = block if block is not None else self.block
760 766 targets = targets if targets is not None else self.targets
761 767 return self.client.kill(targets=targets, block=block)
762 768
763 769 #----------------------------------------
764 770 # activate for %px,%autopx magics
765 771 #----------------------------------------
766 772 def activate(self):
767 773 """Make this `View` active for parallel magic commands.
768 774
769 775 IPython has a magic command syntax to work with `MultiEngineClient` objects.
770 776 In a given IPython session there is a single active one. While
771 777 there can be many `Views` created and used by the user,
772 778 there is only one active one. The active `View` is used whenever
773 779 the magic commands %px and %autopx are used.
774 780
775 781 The activate() method is called on a given `View` to make it
776 782 active. Once this has been done, the magic commands can be used.
777 783 """
778 784
779 785 try:
780 786 # This is injected into __builtins__.
781 787 ip = get_ipython()
782 788 except NameError:
783 789 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
784 790 else:
785 791 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
786 792 if pmagic is None:
787 793 ip.magic_load_ext('parallelmagic')
788 794 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
789 795
790 796 pmagic.active_view = self
791 797
792 798
793 799 @skip_doctest
794 800 class LoadBalancedView(View):
795 801 """An load-balancing View that only executes via the Task scheduler.
796 802
797 803 Load-balanced views can be created with the client's `view` method:
798 804
799 805 >>> v = client.load_balanced_view()
800 806
801 807 or targets can be specified, to restrict the potential destinations:
802 808
803 809 >>> v = client.load_balanced_view([1,3])
804 810
805 811 which would restrict load-balancing to engines 1 and 3.
806 812
807 813 """
808 814
809 815 follow=Any()
810 816 after=Any()
811 817 timeout=CFloat()
812 818 retries = Integer(0)
813 819
814 820 _task_scheme = Any()
815 821 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
816 822
817 823 def __init__(self, client=None, socket=None, **flags):
818 824 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
819 825 self._task_scheme=client._task_scheme
820 826
821 827 def _validate_dependency(self, dep):
822 828 """validate a dependency.
823 829
824 830 For use in `set_flags`.
825 831 """
826 832 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
827 833 return True
828 834 elif isinstance(dep, (list,set, tuple)):
829 835 for d in dep:
830 836 if not isinstance(d, (basestring, AsyncResult)):
831 837 return False
832 838 elif isinstance(dep, dict):
833 839 if set(dep.keys()) != set(Dependency().as_dict().keys()):
834 840 return False
835 841 if not isinstance(dep['msg_ids'], list):
836 842 return False
837 843 for d in dep['msg_ids']:
838 844 if not isinstance(d, basestring):
839 845 return False
840 846 else:
841 847 return False
842 848
843 849 return True
844 850
845 851 def _render_dependency(self, dep):
846 852 """helper for building jsonable dependencies from various input forms."""
847 853 if isinstance(dep, Dependency):
848 854 return dep.as_dict()
849 855 elif isinstance(dep, AsyncResult):
850 856 return dep.msg_ids
851 857 elif dep is None:
852 858 return []
853 859 else:
854 860 # pass to Dependency constructor
855 861 return list(Dependency(dep))
856 862
857 863 def set_flags(self, **kwargs):
858 864 """set my attribute flags by keyword.
859 865
860 866 A View is a wrapper for the Client's apply method, but with attributes
861 867 that specify keyword arguments. Those attributes can be set by keyword
862 868 argument with this method.
863 869
864 870 Parameters
865 871 ----------
866 872
867 873 block : bool
868 874 whether to wait for results
869 875 track : bool
870 876 whether to create a MessageTracker to allow the user to
871 877 safely edit arrays and buffers after non-copying
872 878 sends.
873 879
874 880 after : Dependency or collection of msg_ids
875 881 Only for load-balanced execution (targets=None)
876 882 Specify a list of msg_ids as a time-based dependency.
877 883 This job will only be run *after* the dependencies
878 884 have been met.
879 885
880 886 follow : Dependency or collection of msg_ids
881 887 Only for load-balanced execution (targets=None)
882 888 Specify a list of msg_ids as a location-based dependency.
883 889 This job will only be run on an engine where this dependency
884 890 is met.
885 891
886 892 timeout : float/int or None
887 893 Only for load-balanced execution (targets=None)
888 894 Specify an amount of time (in seconds) for the scheduler to
889 895 wait for dependencies to be met before failing with a
890 896 DependencyTimeout.
891 897
892 898 retries : int
893 899 Number of times a task will be retried on failure.
894 900 """
895 901
896 902 super(LoadBalancedView, self).set_flags(**kwargs)
897 903 for name in ('follow', 'after'):
898 904 if name in kwargs:
899 905 value = kwargs[name]
900 906 if self._validate_dependency(value):
901 907 setattr(self, name, value)
902 908 else:
903 909 raise ValueError("Invalid dependency: %r"%value)
904 910 if 'timeout' in kwargs:
905 911 t = kwargs['timeout']
906 912 if not isinstance(t, (int, long, float, type(None))):
907 913 raise TypeError("Invalid type for timeout: %r"%type(t))
908 914 if t is not None:
909 915 if t < 0:
910 916 raise ValueError("Invalid timeout: %s"%t)
911 917 self.timeout = t
912 918
913 919 @sync_results
914 920 @save_ids
915 921 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
916 922 after=None, follow=None, timeout=None,
917 923 targets=None, retries=None):
918 924 """calls f(*args, **kwargs) on a remote engine, returning the result.
919 925
920 926 This method temporarily sets all of `apply`'s flags for a single call.
921 927
922 928 Parameters
923 929 ----------
924 930
925 931 f : callable
926 932
927 933 args : list [default: empty]
928 934
929 935 kwargs : dict [default: empty]
930 936
931 937 block : bool [default: self.block]
932 938 whether to block
933 939 track : bool [default: self.track]
934 940 whether to ask zmq to track the message, for safe non-copying sends
935 941
936 942 !!!!!! TODO: THE REST HERE !!!!
937 943
938 944 Returns
939 945 -------
940 946
941 947 if self.block is False:
942 948 returns AsyncResult
943 949 else:
944 950 returns actual result of f(*args, **kwargs) on the engine(s)
945 951 This will be a list if self.targets is also a list (even length 1), or
946 952 the single result if self.targets is an integer engine id
947 953 """
948 954
949 955 # validate whether we can run
950 956 if self._socket.closed:
951 957 msg = "Task farming is disabled"
952 958 if self._task_scheme == 'pure':
953 959 msg += " because the pure ZMQ scheduler cannot handle"
954 960 msg += " disappearing engines."
955 961 raise RuntimeError(msg)
956 962
957 963 if self._task_scheme == 'pure':
958 964 # pure zmq scheme doesn't support extra features
959 965 msg = "Pure ZMQ scheduler doesn't support the following flags:"
960 966 "follow, after, retries, targets, timeout"
961 967 if (follow or after or retries or targets or timeout):
962 968 # hard fail on Scheduler flags
963 969 raise RuntimeError(msg)
964 970 if isinstance(f, dependent):
965 971 # soft warn on functional dependencies
966 972 warnings.warn(msg, RuntimeWarning)
967 973
968 974 # build args
969 975 args = [] if args is None else args
970 976 kwargs = {} if kwargs is None else kwargs
971 977 block = self.block if block is None else block
972 978 track = self.track if track is None else track
973 979 after = self.after if after is None else after
974 980 retries = self.retries if retries is None else retries
975 981 follow = self.follow if follow is None else follow
976 982 timeout = self.timeout if timeout is None else timeout
977 983 targets = self.targets if targets is None else targets
978 984
979 985 if not isinstance(retries, int):
980 986 raise TypeError('retries must be int, not %r'%type(retries))
981 987
982 988 if targets is None:
983 989 idents = []
984 990 else:
985 991 idents = self.client._build_targets(targets)[0]
986 992 # ensure *not* bytes
987 993 idents = [ ident.decode() for ident in idents ]
988 994
989 995 after = self._render_dependency(after)
990 996 follow = self._render_dependency(follow)
991 997 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
992 998
993 999 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
994 1000 subheader=subheader)
995 1001 tracker = None if track is False else msg['tracker']
996 1002
997 1003 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
998 1004
999 1005 if block:
1000 1006 try:
1001 1007 return ar.get()
1002 1008 except KeyboardInterrupt:
1003 1009 pass
1004 1010 return ar
1005 1011
1006 1012 @spin_after
1007 1013 @save_ids
1008 1014 def map(self, f, *sequences, **kwargs):
1009 1015 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1010 1016
1011 1017 Parallel version of builtin `map`, load-balanced by this View.
1012 1018
1013 1019 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1014 1020
1015 1021 Each `chunksize` elements will be a separate task, and will be
1016 1022 load-balanced. This lets individual elements be available for iteration
1017 1023 as soon as they arrive.
1018 1024
1019 1025 Parameters
1020 1026 ----------
1021 1027
1022 1028 f : callable
1023 1029 function to be mapped
1024 1030 *sequences: one or more sequences of matching length
1025 1031 the sequences to be distributed and passed to `f`
1026 1032 block : bool [default self.block]
1027 1033 whether to wait for the result or not
1028 1034 track : bool
1029 1035 whether to create a MessageTracker to allow the user to
1030 1036 safely edit arrays and buffers after non-copying
1031 1037 sends.
1032 1038 chunksize : int [default 1]
1033 1039 how many elements should be in each task.
1034 1040 ordered : bool [default True]
1035 1041 Whether the results should be gathered as they arrive, or enforce
1036 1042 the order of submission.
1037 1043
1038 1044 Only applies when iterating through AsyncMapResult as results arrive.
1039 1045 Has no effect when block=True.
1040 1046
1041 1047 Returns
1042 1048 -------
1043 1049
1044 1050 if block=False:
1045 1051 AsyncMapResult
1046 1052 An object like AsyncResult, but which reassembles the sequence of results
1047 1053 into a single list. AsyncMapResults can be iterated through before all
1048 1054 results are complete.
1049 1055 else:
1050 1056 the result of map(f,*sequences)
1051 1057
1052 1058 """
1053 1059
1054 1060 # default
1055 1061 block = kwargs.get('block', self.block)
1056 1062 chunksize = kwargs.get('chunksize', 1)
1057 1063 ordered = kwargs.get('ordered', True)
1058 1064
1059 1065 keyset = set(kwargs.keys())
1060 1066 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1061 1067 if extra_keys:
1062 1068 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1063 1069
1064 1070 assert len(sequences) > 0, "must have some sequences to map onto!"
1065 1071
1066 1072 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1067 1073 return pf.map(*sequences)
1068 1074
1069 1075 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,544 +1,553 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython.testing import decorators as dec
28 28
29 29 from IPython import parallel as pmod
30 30 from IPython.parallel import error
31 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 32 from IPython.parallel import DirectView
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 38
39 39 def setup():
40 40 add_engines(3, total=True)
41 41
42 42 class TestView(ClusterTestCase):
43 43
44 44 def test_z_crash_mux(self):
45 45 """test graceful handling of engine death (direct)"""
46 46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEquals(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEquals(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEquals(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEquals(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEquals(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEquals(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEquals(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 145 ahr = v2.get_result(ar.msg_ids)
146 146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 147 self.assertEquals(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids)
149 149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 150 c.spin()
151 151 c.close()
152 152
153 153 def test_run_newline(self):
154 154 """test that run appends newline to files"""
155 155 tmpfile = mktemp()
156 156 with open(tmpfile, 'w') as f:
157 157 f.write("""def g():
158 158 return 5
159 159 """)
160 160 v = self.client[-1]
161 161 v.run(tmpfile, block=True)
162 162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 163
164 164 def test_apply_tracked(self):
165 165 """test tracking for apply"""
166 166 # self.add_engines(1)
167 167 t = self.client.ids[-1]
168 168 v = self.client[t]
169 169 v.block=False
170 170 def echo(n=1024*1024, **kwargs):
171 171 with v.temp_flags(**kwargs):
172 172 return v.apply(lambda x: x, 'x'*n)
173 173 ar = echo(1, track=False)
174 174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 175 self.assertTrue(ar.sent)
176 176 ar = echo(track=True)
177 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 178 self.assertEquals(ar.sent, ar._tracker.done)
179 179 ar._tracker.wait()
180 180 self.assertTrue(ar.sent)
181 181
182 182 def test_push_tracked(self):
183 183 t = self.client.ids[-1]
184 184 ns = dict(x='x'*1024*1024)
185 185 v = self.client[t]
186 186 ar = v.push(ns, block=False, track=False)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertTrue(ar.sent)
189 189
190 190 ar = v.push(ns, block=False, track=True)
191 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 192 ar._tracker.wait()
193 193 self.assertEquals(ar.sent, ar._tracker.done)
194 194 self.assertTrue(ar.sent)
195 195 ar.get()
196 196
197 197 def test_scatter_tracked(self):
198 198 t = self.client.ids
199 199 x='x'*1024*1024
200 200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertTrue(ar.sent)
203 203
204 204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 206 self.assertEquals(ar.sent, ar._tracker.done)
207 207 ar._tracker.wait()
208 208 self.assertTrue(ar.sent)
209 209 ar.get()
210 210
211 211 def test_remote_reference(self):
212 212 v = self.client[-1]
213 213 v['a'] = 123
214 214 ra = pmod.Reference('a')
215 215 b = v.apply_sync(lambda x: x, ra)
216 216 self.assertEquals(b, 123)
217 217
218 218
219 219 def test_scatter_gather(self):
220 220 view = self.client[:]
221 221 seq1 = range(16)
222 222 view.scatter('a', seq1)
223 223 seq2 = view.gather('a', block=True)
224 224 self.assertEquals(seq2, seq1)
225 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 226
227 227 @skip_without('numpy')
228 228 def test_scatter_gather_numpy(self):
229 229 import numpy
230 230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 231 view = self.client[:]
232 232 a = numpy.arange(64)
233 233 view.scatter('a', a)
234 234 b = view.gather('a', block=True)
235 235 assert_array_equal(b, a)
236
237 def test_scatter_gather_lazy(self):
238 """scatter/gather with targets='all'"""
239 view = self.client.direct_view(targets='all')
240 x = range(64)
241 view.scatter('x', x)
242 gathered = view.gather('x', block=True)
243 self.assertEquals(gathered, x)
244
236 245
237 246 @dec.known_failure_py3
238 247 @skip_without('numpy')
239 248 def test_push_numpy_nocopy(self):
240 249 import numpy
241 250 view = self.client[:]
242 251 a = numpy.arange(64)
243 252 view['A'] = a
244 253 @interactive
245 254 def check_writeable(x):
246 255 return x.flags.writeable
247 256
248 257 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
249 258 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
250 259
251 260 view.push(dict(B=a))
252 261 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
253 262 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
254 263
255 264 @skip_without('numpy')
256 265 def test_apply_numpy(self):
257 266 """view.apply(f, ndarray)"""
258 267 import numpy
259 268 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
260 269
261 270 A = numpy.random.random((100,100))
262 271 view = self.client[-1]
263 272 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
264 273 B = A.astype(dt)
265 274 C = view.apply_sync(lambda x:x, B)
266 275 assert_array_equal(B,C)
267 276
268 277 def test_map(self):
269 278 view = self.client[:]
270 279 def f(x):
271 280 return x**2
272 281 data = range(16)
273 282 r = view.map_sync(f, data)
274 283 self.assertEquals(r, map(f, data))
275 284
276 285 def test_map_iterable(self):
277 286 """test map on iterables (direct)"""
278 287 view = self.client[:]
279 288 # 101 is prime, so it won't be evenly distributed
280 289 arr = range(101)
281 290 # ensure it will be an iterator, even in Python 3
282 291 it = iter(arr)
283 292 r = view.map_sync(lambda x:x, arr)
284 293 self.assertEquals(r, list(arr))
285 294
286 295 def test_scatterGatherNonblocking(self):
287 296 data = range(16)
288 297 view = self.client[:]
289 298 view.scatter('a', data, block=False)
290 299 ar = view.gather('a', block=False)
291 300 self.assertEquals(ar.get(), data)
292 301
293 302 @skip_without('numpy')
294 303 def test_scatter_gather_numpy_nonblocking(self):
295 304 import numpy
296 305 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
297 306 a = numpy.arange(64)
298 307 view = self.client[:]
299 308 ar = view.scatter('a', a, block=False)
300 309 self.assertTrue(isinstance(ar, AsyncResult))
301 310 amr = view.gather('a', block=False)
302 311 self.assertTrue(isinstance(amr, AsyncMapResult))
303 312 assert_array_equal(amr.get(), a)
304 313
305 314 def test_execute(self):
306 315 view = self.client[:]
307 316 # self.client.debug=True
308 317 execute = view.execute
309 318 ar = execute('c=30', block=False)
310 319 self.assertTrue(isinstance(ar, AsyncResult))
311 320 ar = execute('d=[0,1,2]', block=False)
312 321 self.client.wait(ar, 1)
313 322 self.assertEquals(len(ar.get()), len(self.client))
314 323 for c in view['c']:
315 324 self.assertEquals(c, 30)
316 325
317 326 def test_abort(self):
318 327 view = self.client[-1]
319 328 ar = view.execute('import time; time.sleep(1)', block=False)
320 329 ar2 = view.apply_async(lambda : 2)
321 330 ar3 = view.apply_async(lambda : 3)
322 331 view.abort(ar2)
323 332 view.abort(ar3.msg_ids)
324 333 self.assertRaises(error.TaskAborted, ar2.get)
325 334 self.assertRaises(error.TaskAborted, ar3.get)
326 335
327 336 def test_abort_all(self):
328 337 """view.abort() aborts all outstanding tasks"""
329 338 view = self.client[-1]
330 339 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
331 340 view.abort()
332 341 view.wait(timeout=5)
333 342 for ar in ars[5:]:
334 343 self.assertRaises(error.TaskAborted, ar.get)
335 344
336 345 def test_temp_flags(self):
337 346 view = self.client[-1]
338 347 view.block=True
339 348 with view.temp_flags(block=False):
340 349 self.assertFalse(view.block)
341 350 self.assertTrue(view.block)
342 351
343 352 @dec.known_failure_py3
344 353 def test_importer(self):
345 354 view = self.client[-1]
346 355 view.clear(block=True)
347 356 with view.importer:
348 357 import re
349 358
350 359 @interactive
351 360 def findall(pat, s):
352 361 # this globals() step isn't necessary in real code
353 362 # only to prevent a closure in the test
354 363 re = globals()['re']
355 364 return re.findall(pat, s)
356 365
357 366 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
358 367
359 368 # parallel magic tests
360 369
361 370 def test_magic_px_blocking(self):
362 371 ip = get_ipython()
363 372 v = self.client[-1]
364 373 v.activate()
365 374 v.block=True
366 375
367 376 ip.magic_px('a=5')
368 377 self.assertEquals(v['a'], 5)
369 378 ip.magic_px('a=10')
370 379 self.assertEquals(v['a'], 10)
371 380 sio = StringIO()
372 381 savestdout = sys.stdout
373 382 sys.stdout = sio
374 383 # just 'print a' works ~99% of the time, but this ensures that
375 384 # the stdout message has arrived when the result is finished:
376 385 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
377 386 sys.stdout = savestdout
378 387 buf = sio.getvalue()
379 388 self.assertTrue('[stdout:' in buf, buf)
380 389 self.assertTrue(buf.rstrip().endswith('10'))
381 390 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
382 391
383 392 def test_magic_px_nonblocking(self):
384 393 ip = get_ipython()
385 394 v = self.client[-1]
386 395 v.activate()
387 396 v.block=False
388 397
389 398 ip.magic_px('a=5')
390 399 self.assertEquals(v['a'], 5)
391 400 ip.magic_px('a=10')
392 401 self.assertEquals(v['a'], 10)
393 402 sio = StringIO()
394 403 savestdout = sys.stdout
395 404 sys.stdout = sio
396 405 ip.magic_px('print a')
397 406 sys.stdout = savestdout
398 407 buf = sio.getvalue()
399 408 self.assertFalse('[stdout:%i]'%v.targets in buf)
400 409 ip.magic_px('1/0')
401 410 ar = v.get_result(-1)
402 411 self.assertRaisesRemote(ZeroDivisionError, ar.get)
403 412
404 413 def test_magic_autopx_blocking(self):
405 414 ip = get_ipython()
406 415 v = self.client[-1]
407 416 v.activate()
408 417 v.block=True
409 418
410 419 sio = StringIO()
411 420 savestdout = sys.stdout
412 421 sys.stdout = sio
413 422 ip.magic_autopx()
414 423 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
415 424 ip.run_cell('print b')
416 425 ip.run_cell("b/c")
417 426 ip.run_code(compile('b*=2', '', 'single'))
418 427 ip.magic_autopx()
419 428 sys.stdout = savestdout
420 429 output = sio.getvalue().strip()
421 430 self.assertTrue(output.startswith('%autopx enabled'))
422 431 self.assertTrue(output.endswith('%autopx disabled'))
423 432 self.assertTrue('RemoteError: ZeroDivisionError' in output)
424 433 ar = v.get_result(-2)
425 434 self.assertEquals(v['a'], 5)
426 435 self.assertEquals(v['b'], 20)
427 436 self.assertRaisesRemote(ZeroDivisionError, ar.get)
428 437
429 438 def test_magic_autopx_nonblocking(self):
430 439 ip = get_ipython()
431 440 v = self.client[-1]
432 441 v.activate()
433 442 v.block=False
434 443
435 444 sio = StringIO()
436 445 savestdout = sys.stdout
437 446 sys.stdout = sio
438 447 ip.magic_autopx()
439 448 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
440 449 ip.run_cell('print b')
441 450 ip.run_cell("b/c")
442 451 ip.run_code(compile('b*=2', '', 'single'))
443 452 ip.magic_autopx()
444 453 sys.stdout = savestdout
445 454 output = sio.getvalue().strip()
446 455 self.assertTrue(output.startswith('%autopx enabled'))
447 456 self.assertTrue(output.endswith('%autopx disabled'))
448 457 self.assertFalse('ZeroDivisionError' in output)
449 458 ar = v.get_result(-2)
450 459 self.assertEquals(v['a'], 5)
451 460 self.assertEquals(v['b'], 20)
452 461 self.assertRaisesRemote(ZeroDivisionError, ar.get)
453 462
454 463 def test_magic_result(self):
455 464 ip = get_ipython()
456 465 v = self.client[-1]
457 466 v.activate()
458 467 v['a'] = 111
459 468 ra = v['a']
460 469
461 470 ar = ip.magic_result()
462 471 self.assertEquals(ar.msg_ids, [v.history[-1]])
463 472 self.assertEquals(ar.get(), 111)
464 473 ar = ip.magic_result('-2')
465 474 self.assertEquals(ar.msg_ids, [v.history[-2]])
466 475
467 476 def test_unicode_execute(self):
468 477 """test executing unicode strings"""
469 478 v = self.client[-1]
470 479 v.block=True
471 480 if sys.version_info[0] >= 3:
472 481 code="a='é'"
473 482 else:
474 483 code=u"a=u'é'"
475 484 v.execute(code)
476 485 self.assertEquals(v['a'], u'é')
477 486
478 487 def test_unicode_apply_result(self):
479 488 """test unicode apply results"""
480 489 v = self.client[-1]
481 490 r = v.apply_sync(lambda : u'é')
482 491 self.assertEquals(r, u'é')
483 492
484 493 def test_unicode_apply_arg(self):
485 494 """test passing unicode arguments to apply"""
486 495 v = self.client[-1]
487 496
488 497 @interactive
489 498 def check_unicode(a, check):
490 499 assert isinstance(a, unicode), "%r is not unicode"%a
491 500 assert isinstance(check, bytes), "%r is not bytes"%check
492 501 assert a.encode('utf8') == check, "%s != %s"%(a,check)
493 502
494 503 for s in [ u'é', u'ßø®∫',u'asdf' ]:
495 504 try:
496 505 v.apply_sync(check_unicode, s, s.encode('utf8'))
497 506 except error.RemoteError as e:
498 507 if e.ename == 'AssertionError':
499 508 self.fail(e.evalue)
500 509 else:
501 510 raise e
502 511
503 512 def test_map_reference(self):
504 513 """view.map(<Reference>, *seqs) should work"""
505 514 v = self.client[:]
506 515 v.scatter('n', self.client.ids, flatten=True)
507 516 v.execute("f = lambda x,y: x*y")
508 517 rf = pmod.Reference('f')
509 518 nlist = list(range(10))
510 519 mlist = nlist[::-1]
511 520 expected = [ m*n for m,n in zip(mlist, nlist) ]
512 521 result = v.map_sync(rf, mlist, nlist)
513 522 self.assertEquals(result, expected)
514 523
515 524 def test_apply_reference(self):
516 525 """view.apply(<Reference>, *args) should work"""
517 526 v = self.client[:]
518 527 v.scatter('n', self.client.ids, flatten=True)
519 528 v.execute("f = lambda x: n*x")
520 529 rf = pmod.Reference('f')
521 530 result = v.apply_sync(rf, 5)
522 531 expected = [ 5*id for id in self.client.ids ]
523 532 self.assertEquals(result, expected)
524 533
525 534 def test_eval_reference(self):
526 535 v = self.client[self.client.ids[0]]
527 536 v['g'] = range(5)
528 537 rg = pmod.Reference('g[0]')
529 538 echo = lambda x:x
530 539 self.assertEquals(v.apply_sync(echo, rg), 0)
531 540
532 541 def test_reference_nameerror(self):
533 542 v = self.client[self.client.ids[0]]
534 543 r = pmod.Reference('elvis_has_left')
535 544 echo = lambda x:x
536 545 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
537 546
538 547 def test_single_engine_map(self):
539 548 e0 = self.client[self.client.ids[0]]
540 549 r = range(5)
541 550 check = [ -1*i for i in r ]
542 551 result = e0.map_sync(lambda x: -1*x, r)
543 552 self.assertEquals(result, check)
544 553