Show More
@@ -40,6 +40,7 b' from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,' | |||||
40 | from IPython.external.decorator import decorator |
|
40 | from IPython.external.decorator import decorator | |
41 | from IPython.external.ssh import tunnel |
|
41 | from IPython.external.ssh import tunnel | |
42 |
|
42 | |||
|
43 | from IPython.parallel import Reference | |||
43 | from IPython.parallel import error |
|
44 | from IPython.parallel import error | |
44 | from IPython.parallel import util |
|
45 | from IPython.parallel import util | |
45 |
|
46 | |||
@@ -982,7 +983,7 b' class Client(HasTraits):' | |||||
982 | subheader = subheader if subheader is not None else {} |
|
983 | subheader = subheader if subheader is not None else {} | |
983 |
|
984 | |||
984 | # validate arguments |
|
985 | # validate arguments | |
985 | if not callable(f): |
|
986 | if not callable(f) and not isinstance(f, Reference): | |
986 | raise TypeError("f must be callable, not %s"%type(f)) |
|
987 | raise TypeError("f must be callable, not %s"%type(f)) | |
987 | if not isinstance(args, (tuple, list)): |
|
988 | if not isinstance(args, (tuple, list)): | |
988 | raise TypeError("args must be tuple or list, not %s"%type(args)) |
|
989 | raise TypeError("args must be tuple or list, not %s"%type(args)) |
@@ -27,7 +27,7 b' from . import map as Map' | |||||
27 | from .asyncresult import AsyncMapResult |
|
27 | from .asyncresult import AsyncMapResult | |
28 |
|
28 | |||
29 | #----------------------------------------------------------------------------- |
|
29 | #----------------------------------------------------------------------------- | |
30 | # Decorators |
|
30 | # Functions and Decorators | |
31 | #----------------------------------------------------------------------------- |
|
31 | #----------------------------------------------------------------------------- | |
32 |
|
32 | |||
33 | @skip_doctest |
|
33 | @skip_doctest | |
@@ -60,6 +60,25 b" def parallel(view, dist='b', block=None, ordered=True, **flags):" | |||||
60 | return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags) |
|
60 | return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags) | |
61 | return parallel_function |
|
61 | return parallel_function | |
62 |
|
62 | |||
|
63 | def getname(f): | |||
|
64 | """Get the name of an object. | |||
|
65 | ||||
|
66 | For use in case of callables that are not functions, and | |||
|
67 | thus may not have __name__ defined. | |||
|
68 | ||||
|
69 | Order: f.__name__ > f.name > str(f) | |||
|
70 | """ | |||
|
71 | try: | |||
|
72 | return f.__name__ | |||
|
73 | except: | |||
|
74 | pass | |||
|
75 | try: | |||
|
76 | return f.name | |||
|
77 | except: | |||
|
78 | pass | |||
|
79 | ||||
|
80 | return str(f) | |||
|
81 | ||||
63 | #-------------------------------------------------------------------------- |
|
82 | #-------------------------------------------------------------------------- | |
64 | # Classes |
|
83 | # Classes | |
65 | #-------------------------------------------------------------------------- |
|
84 | #-------------------------------------------------------------------------- | |
@@ -194,7 +213,7 b' class ParallelFunction(RemoteFunction):' | |||||
194 | msg_ids.append(ar.msg_ids[0]) |
|
213 | msg_ids.append(ar.msg_ids[0]) | |
195 |
|
214 | |||
196 | r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, |
|
215 | r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, | |
197 |
fname=self.func |
|
216 | fname=getname(self.func), | |
198 | ordered=self.ordered |
|
217 | ordered=self.ordered | |
199 | ) |
|
218 | ) | |
200 |
|
219 |
@@ -34,7 +34,7 b' from IPython.parallel.controller.dependency import Dependency, dependent' | |||||
34 |
|
34 | |||
35 | from . import map as Map |
|
35 | from . import map as Map | |
36 | from .asyncresult import AsyncResult, AsyncMapResult |
|
36 | from .asyncresult import AsyncResult, AsyncMapResult | |
37 | from .remotefunction import ParallelFunction, parallel, remote |
|
37 | from .remotefunction import ParallelFunction, parallel, remote, getname | |
38 |
|
38 | |||
39 | #----------------------------------------------------------------------------- |
|
39 | #----------------------------------------------------------------------------- | |
40 | # Decorators |
|
40 | # Decorators | |
@@ -535,7 +535,7 b' class DirectView(View):' | |||||
535 | trackers.append(msg['tracker']) |
|
535 | trackers.append(msg['tracker']) | |
536 | msg_ids.append(msg['header']['msg_id']) |
|
536 | msg_ids.append(msg['header']['msg_id']) | |
537 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
537 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |
538 |
ar = AsyncResult(self.client, msg_ids, fname=f |
|
538 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker) | |
539 | if block: |
|
539 | if block: | |
540 | try: |
|
540 | try: | |
541 | return ar.get() |
|
541 | return ar.get() | |
@@ -990,7 +990,7 b' class LoadBalancedView(View):' | |||||
990 | subheader=subheader) |
|
990 | subheader=subheader) | |
991 | tracker = None if track is False else msg['tracker'] |
|
991 | tracker = None if track is False else msg['tracker'] | |
992 |
|
992 | |||
993 |
ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f |
|
993 | ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker) | |
994 |
|
994 | |||
995 | if block: |
|
995 | if block: | |
996 | try: |
|
996 | try: |
@@ -469,4 +469,25 b' class TestView(ClusterTestCase):' | |||||
469 | else: |
|
469 | else: | |
470 | raise e |
|
470 | raise e | |
471 |
|
471 | |||
|
472 | def test_map_reference(self): | |||
|
473 | """view.map(<Reference>, *seqs) should work""" | |||
|
474 | v = self.client[:] | |||
|
475 | v.scatter('n', self.client.ids, flatten=True) | |||
|
476 | v.execute("f = lambda x,y: x*y") | |||
|
477 | rf = pmod.Reference('f') | |||
|
478 | nlist = list(range(10)) | |||
|
479 | mlist = nlist[::-1] | |||
|
480 | expected = [ m*n for m,n in zip(mlist, nlist) ] | |||
|
481 | result = v.map_sync(rf, mlist, nlist) | |||
|
482 | self.assertEquals(result, expected) | |||
|
483 | ||||
|
484 | def test_apply_reference(self): | |||
|
485 | """view.apply(<Reference>, *args) should work""" | |||
|
486 | v = self.client[:] | |||
|
487 | v.scatter('n', self.client.ids, flatten=True) | |||
|
488 | v.execute("f = lambda x: n*x") | |||
|
489 | rf = pmod.Reference('f') | |||
|
490 | result = v.apply_sync(rf, 5) | |||
|
491 | expected = [ 5*id for id in self.client.ids ] | |||
|
492 | self.assertEquals(result, expected) | |||
472 |
|
493 |
General Comments 0
You need to be logged in to leave comments.
Login now