Show More
@@ -40,6 +40,7 b' from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,' | |||
|
40 | 40 | from IPython.external.decorator import decorator |
|
41 | 41 | from IPython.external.ssh import tunnel |
|
42 | 42 | |
|
43 | from IPython.parallel import Reference | |
|
43 | 44 | from IPython.parallel import error |
|
44 | 45 | from IPython.parallel import util |
|
45 | 46 | |
@@ -982,7 +983,7 b' class Client(HasTraits):' | |||
|
982 | 983 | subheader = subheader if subheader is not None else {} |
|
983 | 984 | |
|
984 | 985 | # validate arguments |
|
985 | if not callable(f): | |
|
986 | if not callable(f) and not isinstance(f, Reference): | |
|
986 | 987 | raise TypeError("f must be callable, not %s"%type(f)) |
|
987 | 988 | if not isinstance(args, (tuple, list)): |
|
988 | 989 | raise TypeError("args must be tuple or list, not %s"%type(args)) |
@@ -27,7 +27,7 b' from . import map as Map' | |||
|
27 | 27 | from .asyncresult import AsyncMapResult |
|
28 | 28 | |
|
29 | 29 | #----------------------------------------------------------------------------- |
|
30 | # Decorators | |
|
30 | # Functions and Decorators | |
|
31 | 31 | #----------------------------------------------------------------------------- |
|
32 | 32 | |
|
33 | 33 | @skip_doctest |
@@ -60,6 +60,25 b" def parallel(view, dist='b', block=None, ordered=True, **flags):" | |||
|
60 | 60 | return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags) |
|
61 | 61 | return parallel_function |
|
62 | 62 | |
|
63 | def getname(f): | |
|
64 | """Get the name of an object. | |
|
65 | ||
|
66 | For use in case of callables that are not functions, and | |
|
67 | thus may not have __name__ defined. | |
|
68 | ||
|
69 | Order: f.__name__ > f.name > str(f) | |
|
70 | """ | |
|
71 | try: | |
|
72 | return f.__name__ | |
|
73 | except: | |
|
74 | pass | |
|
75 | try: | |
|
76 | return f.name | |
|
77 | except: | |
|
78 | pass | |
|
79 | ||
|
80 | return str(f) | |
|
81 | ||
|
63 | 82 | #-------------------------------------------------------------------------- |
|
64 | 83 | # Classes |
|
65 | 84 | #-------------------------------------------------------------------------- |
@@ -194,7 +213,7 b' class ParallelFunction(RemoteFunction):' | |||
|
194 | 213 | msg_ids.append(ar.msg_ids[0]) |
|
195 | 214 | |
|
196 | 215 | r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, |
|
197 |
fname=self.func |
|
|
216 | fname=getname(self.func), | |
|
198 | 217 | ordered=self.ordered |
|
199 | 218 | ) |
|
200 | 219 |
@@ -34,7 +34,7 b' from IPython.parallel.controller.dependency import Dependency, dependent' | |||
|
34 | 34 | |
|
35 | 35 | from . import map as Map |
|
36 | 36 | from .asyncresult import AsyncResult, AsyncMapResult |
|
37 | from .remotefunction import ParallelFunction, parallel, remote | |
|
37 | from .remotefunction import ParallelFunction, parallel, remote, getname | |
|
38 | 38 | |
|
39 | 39 | #----------------------------------------------------------------------------- |
|
40 | 40 | # Decorators |
@@ -535,7 +535,7 b' class DirectView(View):' | |||
|
535 | 535 | trackers.append(msg['tracker']) |
|
536 | 536 | msg_ids.append(msg['header']['msg_id']) |
|
537 | 537 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
538 |
ar = AsyncResult(self.client, msg_ids, fname=f |
|
|
538 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker) | |
|
539 | 539 | if block: |
|
540 | 540 | try: |
|
541 | 541 | return ar.get() |
@@ -990,7 +990,7 b' class LoadBalancedView(View):' | |||
|
990 | 990 | subheader=subheader) |
|
991 | 991 | tracker = None if track is False else msg['tracker'] |
|
992 | 992 | |
|
993 |
ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f |
|
|
993 | ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker) | |
|
994 | 994 | |
|
995 | 995 | if block: |
|
996 | 996 | try: |
@@ -469,4 +469,25 b' class TestView(ClusterTestCase):' | |||
|
469 | 469 | else: |
|
470 | 470 | raise e |
|
471 | 471 | |
|
472 | def test_map_reference(self): | |
|
473 | """view.map(<Reference>, *seqs) should work""" | |
|
474 | v = self.client[:] | |
|
475 | v.scatter('n', self.client.ids, flatten=True) | |
|
476 | v.execute("f = lambda x,y: x*y") | |
|
477 | rf = pmod.Reference('f') | |
|
478 | nlist = list(range(10)) | |
|
479 | mlist = nlist[::-1] | |
|
480 | expected = [ m*n for m,n in zip(mlist, nlist) ] | |
|
481 | result = v.map_sync(rf, mlist, nlist) | |
|
482 | self.assertEquals(result, expected) | |
|
483 | ||
|
484 | def test_apply_reference(self): | |
|
485 | """view.apply(<Reference>, *args) should work""" | |
|
486 | v = self.client[:] | |
|
487 | v.scatter('n', self.client.ids, flatten=True) | |
|
488 | v.execute("f = lambda x: n*x") | |
|
489 | rf = pmod.Reference('f') | |
|
490 | result = v.apply_sync(rf, 5) | |
|
491 | expected = [ 5*id for id in self.client.ids ] | |
|
492 | self.assertEquals(result, expected) | |
|
472 | 493 |
General Comments 0
You need to be logged in to leave comments.
Login now