##// END OF EJS Templates
allow Reference as callable in map/apply...
MinRK -
Show More
@@ -40,6 +40,7 b' from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,'
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import Reference
43 from IPython.parallel import error
44 from IPython.parallel import error
44 from IPython.parallel import util
45 from IPython.parallel import util
45
46
@@ -982,7 +983,7 b' class Client(HasTraits):'
982 subheader = subheader if subheader is not None else {}
983 subheader = subheader if subheader is not None else {}
983
984
984 # validate arguments
985 # validate arguments
985 if not callable(f):
986 if not callable(f) and not isinstance(f, Reference):
986 raise TypeError("f must be callable, not %s"%type(f))
987 raise TypeError("f must be callable, not %s"%type(f))
987 if not isinstance(args, (tuple, list)):
988 if not isinstance(args, (tuple, list)):
988 raise TypeError("args must be tuple or list, not %s"%type(args))
989 raise TypeError("args must be tuple or list, not %s"%type(args))
@@ -27,7 +27,7 b' from . import map as Map'
27 from .asyncresult import AsyncMapResult
27 from .asyncresult import AsyncMapResult
28
28
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 # Decorators
30 # Functions and Decorators
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 @skip_doctest
33 @skip_doctest
@@ -60,6 +60,25 b" def parallel(view, dist='b', block=None, ordered=True, **flags):"
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
60 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return parallel_function
61 return parallel_function
62
62
63 def getname(f):
64 """Get the name of an object.
65
66 For use in case of callables that are not functions, and
67 thus may not have __name__ defined.
68
69 Order: f.__name__ > f.name > str(f)
70 """
71 try:
72 return f.__name__
73 except:
74 pass
75 try:
76 return f.name
77 except:
78 pass
79
80 return str(f)
81
63 #--------------------------------------------------------------------------
82 #--------------------------------------------------------------------------
64 # Classes
83 # Classes
65 #--------------------------------------------------------------------------
84 #--------------------------------------------------------------------------
@@ -194,7 +213,7 b' class ParallelFunction(RemoteFunction):'
194 msg_ids.append(ar.msg_ids[0])
213 msg_ids.append(ar.msg_ids[0])
195
214
196 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
215 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
197 fname=self.func.__name__,
216 fname=getname(self.func),
198 ordered=self.ordered
217 ordered=self.ordered
199 )
218 )
200
219
@@ -34,7 +34,7 b' from IPython.parallel.controller.dependency import Dependency, dependent'
34
34
35 from . import map as Map
35 from . import map as Map
36 from .asyncresult import AsyncResult, AsyncMapResult
36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Decorators
40 # Decorators
@@ -535,7 +535,7 b' class DirectView(View):'
535 trackers.append(msg['tracker'])
535 trackers.append(msg['tracker'])
536 msg_ids.append(msg['header']['msg_id'])
536 msg_ids.append(msg['header']['msg_id'])
537 tracker = None if track is False else zmq.MessageTracker(*trackers)
537 tracker = None if track is False else zmq.MessageTracker(*trackers)
538 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
538 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
539 if block:
539 if block:
540 try:
540 try:
541 return ar.get()
541 return ar.get()
@@ -990,7 +990,7 b' class LoadBalancedView(View):'
990 subheader=subheader)
990 subheader=subheader)
991 tracker = None if track is False else msg['tracker']
991 tracker = None if track is False else msg['tracker']
992
992
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
993 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
994
994
995 if block:
995 if block:
996 try:
996 try:
@@ -469,4 +469,25 b' class TestView(ClusterTestCase):'
469 else:
469 else:
470 raise e
470 raise e
471
471
472 def test_map_reference(self):
473 """view.map(<Reference>, *seqs) should work"""
474 v = self.client[:]
475 v.scatter('n', self.client.ids, flatten=True)
476 v.execute("f = lambda x,y: x*y")
477 rf = pmod.Reference('f')
478 nlist = list(range(10))
479 mlist = nlist[::-1]
480 expected = [ m*n for m,n in zip(mlist, nlist) ]
481 result = v.map_sync(rf, mlist, nlist)
482 self.assertEquals(result, expected)
483
484 def test_apply_reference(self):
485 """view.apply(<Reference>, *args) should work"""
486 v = self.client[:]
487 v.scatter('n', self.client.ids, flatten=True)
488 v.execute("f = lambda x: n*x")
489 rf = pmod.Reference('f')
490 result = v.apply_sync(rf, 5)
491 expected = [ 5*id for id in self.client.ids ]
492 self.assertEquals(result, expected)
472
493
General Comments 0
You need to be logged in to leave comments. Login now