From 287b095cfd3b8a9c97e02717085cfba12190fea8 2012-04-06 21:34:50 From: MinRK Date: 2012-04-06 21:34:50 Subject: [PATCH] allow Reference as callable in map/apply Assumptions were made that the first argument was a callable/function with a __name__ attribute. These assumptions were the only barrier to using References, and have been removed. Associated tests included. --- diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index 7b8735f..82e1e0a 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -40,6 +40,7 @@ from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode, from IPython.external.decorator import decorator from IPython.external.ssh import tunnel +from IPython.parallel import Reference from IPython.parallel import error from IPython.parallel import util @@ -982,7 +983,7 @@ class Client(HasTraits): subheader = subheader if subheader is not None else {} # validate arguments - if not callable(f): + if not callable(f) and not isinstance(f, Reference): raise TypeError("f must be callable, not %s"%type(f)) if not isinstance(args, (tuple, list)): raise TypeError("args must be tuple or list, not %s"%type(args)) diff --git a/IPython/parallel/client/remotefunction.py b/IPython/parallel/client/remotefunction.py index f7850f6..ea47005 100644 --- a/IPython/parallel/client/remotefunction.py +++ b/IPython/parallel/client/remotefunction.py @@ -27,7 +27,7 @@ from . import map as Map from .asyncresult import AsyncMapResult #----------------------------------------------------------------------------- -# Decorators +# Functions and Decorators #----------------------------------------------------------------------------- @skip_doctest @@ -60,6 +60,25 @@ def parallel(view, dist='b', block=None, ordered=True, **flags): return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags) return parallel_function +def getname(f): + """Get the name of an object.
+ + For use in case of callables that are not functions, and + thus may not have __name__ defined. + + Order: f.__name__ > f.name > str(f) + """ + try: + return f.__name__ + except AttributeError: + pass + try: + return f.name + except AttributeError: + pass + + return str(f) + #-------------------------------------------------------------------------- # Classes #-------------------------------------------------------------------------- @@ -194,7 +213,7 @@ class ParallelFunction(RemoteFunction): msg_ids.append(ar.msg_ids[0]) r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, - fname=self.func.__name__, + fname=getname(self.func), ordered=self.ordered ) diff --git a/IPython/parallel/client/view.py b/IPython/parallel/client/view.py index 92e876f..60388b1 100644 --- a/IPython/parallel/client/view.py +++ b/IPython/parallel/client/view.py @@ -34,7 +34,7 @@ from IPython.parallel.controller.dependency import Dependency, dependent from . import map as Map from .asyncresult import AsyncResult, AsyncMapResult -from .remotefunction import ParallelFunction, parallel, remote +from .remotefunction import ParallelFunction, parallel, remote, getname #----------------------------------------------------------------------------- # Decorators @@ -535,7 +535,7 @@ class DirectView(View): trackers.append(msg['tracker']) msg_ids.append(msg['header']['msg_id']) tracker = None if track is False else zmq.MessageTracker(*trackers) - ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker) + ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker) if block: try: return ar.get() @@ -990,7 +990,7 @@ class LoadBalancedView(View): subheader=subheader) tracker = None if track is False else msg['tracker'] - ar = AsyncResult(self.client, msg['header']['msg_id'], fname=f.__name__, targets=None, tracker=tracker) + ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker) if block: try: diff --git 
a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py index 6e3e559..fa0a9c1 100644 --- a/IPython/parallel/tests/test_view.py +++ b/IPython/parallel/tests/test_view.py @@ -469,4 +469,25 @@ class TestView(ClusterTestCase): else: raise e + def test_map_reference(self): + """view.map(<Reference>, *seqs) should work""" + v = self.client[:] + v.scatter('n', self.client.ids, flatten=True) + v.execute("f = lambda x,y: x*y") + rf = pmod.Reference('f') + nlist = list(range(10)) + mlist = nlist[::-1] + expected = [ m*n for m,n in zip(mlist, nlist) ] + result = v.map_sync(rf, mlist, nlist) + self.assertEqual(result, expected) + + def test_apply_reference(self): + """view.apply(<Reference>, *args) should work""" + v = self.client[:] + v.scatter('n', self.client.ids, flatten=True) + v.execute("f = lambda x: n*x") + rf = pmod.Reference('f') + result = v.apply_sync(rf, 5) + expected = [ 5*eid for eid in self.client.ids ] + self.assertEqual(result, expected)