From f0f04372539073e0ae9d48b908fa905a0f741ac0 2013-08-14 03:24:46 From: Samuel Ainsworth Date: 2013-08-14 03:24:46 Subject: [PATCH] Fix parallel.client.View map() on numpy arrays --- diff --git a/IPython/parallel/client/remotefunction.py b/IPython/parallel/client/remotefunction.py index 07e9327..4b31a0b 100644 --- a/IPython/parallel/client/remotefunction.py +++ b/IPython/parallel/client/remotefunction.py @@ -232,7 +232,8 @@ class ParallelFunction(RemoteFunction): for seq in sequences: part = self.mapObject.getPartition(seq, index, nparts, maxlen) args.append(part) - if not any(args): + + if sum([len(arg) for arg in args]) == 0: continue if self._mapping: diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py index 7a3540e..b44994b 100644 --- a/IPython/parallel/tests/test_view.py +++ b/IPython/parallel/tests/test_view.py @@ -344,6 +344,18 @@ class TestView(ClusterTestCase, ParametricTestCase): it = iter(arr) r = view.map_sync(lambda x:x, arr) self.assertEqual(r, list(arr)) + + @skip_without('numpy') + def test_map_numpy(self): + """test map on numpy arrays (direct)""" + import numpy + from numpy.testing.utils import assert_array_equal + + view = self.client[:] + # 101 is prime, so it won't be evenly distributed + arr = numpy.arange(101) + r = view.map_sync(lambda x: x, arr) + assert_array_equal(r, arr) def test_scatter_gather_nonblocking(self): data = range(16)