##// END OF EJS Templates
allow Reference as callable in map/apply...
MinRK -
Show More
@@ -40,6 +40,7 b' from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,'
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 from IPython.parallel import Reference
43 44 from IPython.parallel import error
44 45 from IPython.parallel import util
45 46
@@ -982,7 +983,7 b' class Client(HasTraits):'
982 983 subheader = subheader if subheader is not None else {}
983 984
984 985 # validate arguments
985 if not callable(f):
986 if not callable(f) and not isinstance(f, Reference):
986 987 raise TypeError("f must be callable, not %s"%type(f))
987 988 if not isinstance(args, (tuple, list)):
988 989 raise TypeError("args must be tuple or list, not %s"%type(args))
@@ -27,7 +27,7 b' from . import map as Map'
27 27 from .asyncresult import AsyncMapResult
28 28
29 29 #-----------------------------------------------------------------------------
30 # Decorators
30 # Functions and Decorators
31 31 #-----------------------------------------------------------------------------
32 32
33 33 @skip_doctest
@@ -60,6 +60,25 b" def parallel(view, dist='b', block=None, ordered=True, **flags):"
60 60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 61 return parallel_function
62 62
63 def getname(f):
64 """Get the name of an object.
65
66 For use in case of callables that are not functions, and
67 thus may not have __name__ defined.
68
69 Order: f.__name__ > f.name > str(f)
70 """
71 try:
72 return f.__name__
73 except:
74 pass
75 try:
76 return f.name
77 except:
78 pass
79
80 return str(f)
81
63 82 #--------------------------------------------------------------------------
64 83 # Classes
65 84 #--------------------------------------------------------------------------
@@ -194,7 +213,7 b' class ParallelFunction(RemoteFunction):'
194 213 msg_ids.append(ar.msg_ids[0])
195 214
196 215 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
197 fname=self.func.__name__,
216 fname=getname(self.func),
198 217 ordered=self.ordered
199 218 )
200 219
@@ -34,7 +34,7 b' from IPython.parallel.controller.dependency import Dependency, dependent'
34 34
35 35 from . import map as Map
36 36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Decorators
@@ -535,7 +535,7 b' class DirectView(View):'
535 535 trackers.append(msg['tracker'])
536 536 msg_ids.append(msg['header']['msg_id'])
537 537 tracker = None if track is False else zmq.MessageTracker(*trackers)
538 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
538 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
539 539 if block:
540 540 try:
541 541 return ar.get()
@@ -990,7 +990,7 b' class LoadBalancedView(View):'
990 990 subheader=subheader)
991 991 tracker = None if track is False else msg['tracker']
992 992
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
994 994
995 995 if block:
996 996 try:
@@ -469,4 +469,25 b' class TestView(ClusterTestCase):'
469 469 else:
470 470 raise e
471 471
472 def test_map_reference(self):
473 """view.map(<Reference>, *seqs) should work"""
474 v = self.client[:]
475 v.scatter('n', self.client.ids, flatten=True)
476 v.execute("f = lambda x,y: x*y")
477 rf = pmod.Reference('f')
478 nlist = list(range(10))
479 mlist = nlist[::-1]
480 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 result = v.map_sync(rf, mlist, nlist)
482 self.assertEquals(result, expected)
483
484 def test_apply_reference(self):
485 """view.apply(<Reference>, *args) should work"""
486 v = self.client[:]
487 v.scatter('n', self.client.ids, flatten=True)
488 v.execute("f = lambda x: n*x")
489 rf = pmod.Reference('f')
490 result = v.apply_sync(rf, 5)
491 expected = [ 5*id for id in self.client.ids ]
492 self.assertEquals(result, expected)
472 493
General Comments 0
You need to be logged in to leave comments. Login now