diff --git a/IPython/parallel/client/map.py b/IPython/parallel/client/map.py index f9ee3a3..af2d863 100644 --- a/IPython/parallel/client/map.py +++ b/IPython/parallel/client/map.py @@ -59,17 +59,19 @@ else: class Map(object): """A class for partitioning a sequence using a map.""" - def getPartition(self, seq, p, q): - """Returns the pth partition of q partitions of seq.""" + def getPartition(self, seq, p, q, n=None): + """Returns the pth partition of q partitions of seq. + The length can be specified as `n`, + otherwise it is the value of `len(seq)` + """ + n = len(seq) if n is None else n # Test for error conditions here if p<0 or p>=q: - print "No partition exists." - return + raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q)) - N = len(seq) - remainder = N % q - basesize = N // q + remainder = n % q + basesize = n // q if p < remainder: low = p * (basesize + 1) @@ -104,19 +106,14 @@ class Map(object): return listOfPartitions class RoundRobinMap(Map): - """Partitions a sequence in a roun robin fashion. + """Partitions a sequence in a round robin fashion. This currently does not work! """ - def getPartition(self, seq, p, q): - # if not isinstance(seq,(list,tuple)): - # raise NotImplementedError("cannot RR partition type %s"%type(seq)) - return seq[p:len(seq):q] - #result = [] - #for i in range(p,len(seq),q): - # result.append(seq[i]) - #return result + def getPartition(self, seq, p, q, n=None): + n = len(seq) if n is None else n + return seq[p:n:q] def joinPartitions(self, listOfPartitions): testObject = listOfPartitions[0] diff --git a/IPython/parallel/client/remotefunction.py b/IPython/parallel/client/remotefunction.py index 2fc2239..07e9327 100644 --- a/IPython/parallel/client/remotefunction.py +++ b/IPython/parallel/client/remotefunction.py @@ -161,13 +161,15 @@ class ParallelFunction(RemoteFunction): chunksize : int or None The size of chunk to use when breaking up sequences in a load-balanced manner ordered : bool [default: True] - Whether + Whether the result should be kept in order. If False, + results become available as they arrive, regardless of submission order. **flags : remaining kwargs are passed to View.temp_flags """ - chunksize=None - ordered=None - mapObject=None + chunksize = None + ordered = None + mapObject = None + _mapping = False def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags): super(ParallelFunction, self).__init__(view, f, block=block, **flags) @@ -176,23 +178,40 @@ class ParallelFunction(RemoteFunction): mapClass = Map.dists[dist] self.mapObject = mapClass() - + @sync_view_results def __call__(self, *sequences): client = self.view.client + lens = [] + maxlen = minlen = -1 + for i, seq in enumerate(sequences): + try: + n = len(seq) + except Exception: + seq = list(seq) + if isinstance(sequences, tuple): + # can't alter a tuple + sequences = list(sequences) + sequences[i] = seq + n = len(seq) + if n > maxlen: + maxlen = n + if minlen == -1 or n < minlen: + minlen = n + lens.append(n) + # check that the length of sequences match - len_0 = len(sequences[0]) - for s in sequences: - if len(s)!=len_0: - msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s)) - raise ValueError(msg) + if not self._mapping and minlen != maxlen: + msg = 'all sequences must have equal length, but have %s' % lens + raise ValueError(msg) + balanced = 'Balanced' in self.view.__class__.__name__ if balanced: if self.chunksize: - nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0) + nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0) else: - nparts = len_0 + nparts = maxlen targets = [None]*nparts else: if self.chunksize: @@ -211,21 +230,17 @@ class ParallelFunction(RemoteFunction): for index, t in enumerate(targets): args = [] for seq in sequences: - part = self.mapObject.getPartition(seq, index, nparts) - if len(part) == 0: - continue - else: - args.append(part) - if not args: + part = self.mapObject.getPartition(seq, index, nparts, maxlen) + args.append(part) + if not any(args): continue - # print (args) - if hasattr(self, '_map'): + if self._mapping: if sys.version_info[0] >= 3: f = lambda f, *sequences: list(map(f, *sequences)) else: f = map - args = [self.func]+args + args = [self.func] + args else: f=self.func @@ -233,9 +248,9 @@ class ParallelFunction(RemoteFunction): with view.temp_flags(block=False, **self.flags): ar = view.apply(f, *args) - msg_ids.append(ar.msg_ids[0]) + msg_ids.extend(ar.msg_ids) - r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, + r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, fname=getname(self.func), ordered=self.ordered ) @@ -249,16 +264,19 @@ class ParallelFunction(RemoteFunction): return r def map(self, *sequences): - """call a function on each element of a sequence remotely. + """call a function on each element of one or more sequence(s) remotely. This should behave very much like the builtin map, but return an AsyncMapResult if self.block is False. + + That means it can take generators (will be cast to lists locally), + and mismatched sequence lengths will be padded with None. """ - # set _map as a flag for use inside self.__call__ - self._map = True + # set _mapping as a flag for use inside self.__call__ + self._mapping = True try: - ret = self.__call__(*sequences) + ret = self(*sequences) finally: - del self._map + self._mapping = False return ret __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction'] diff --git a/IPython/parallel/tests/test_lbview.py b/IPython/parallel/tests/test_lbview.py index 4b04520..dd8f57a 100644 --- a/IPython/parallel/tests/test_lbview.py +++ b/IPython/parallel/tests/test_lbview.py @@ -58,6 +58,40 @@ class TestLoadBalancedView(ClusterTestCase): data = range(16) r = self.view.map_sync(f, data) self.assertEqual(r, map(f, data)) + + def test_map_generator(self): + def f(x): + return x**2 + + data = range(16) + r = self.view.map_sync(f, iter(data)) + self.assertEqual(r, map(f, iter(data))) + + def test_map_short_first(self): + def f(x,y): + if y is None: + return y + if x is None: + return x + return x*y + data = range(10) + data2 = range(4) + + r = self.view.map_sync(f, data, data2) + self.assertEqual(r, map(f, data, data2)) + + def test_map_short_last(self): + def f(x,y): + if y is None: + return y + if x is None: + return x + return x*y + data = range(4) + data2 = range(10) + + r = self.view.map_sync(f, data, data2) + self.assertEqual(r, map(f, data, data2)) def test_map_unordered(self): def f(x):