From 15b756705416d0276bd44f88d67f3022feb5e6a7 2011-04-08 00:38:13 From: MinRK Date: 2011-04-08 00:38:13 Subject: [PATCH] add map/scatter/gather/ParallelFunction from kernel --- diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 39e40e6..2d40547 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -28,6 +28,7 @@ import streamsession as ss from view import DirectView, LoadBalancedView from dependency import Dependency, depend, require import error +import map as Map #-------------------------------------------------------------------------- # helpers for implementing old MEC API via client.apply @@ -92,6 +93,18 @@ def remote(client, bound=False, block=None, targets=None): return RemoteFunction(client, f, bound, block, targets) return remote_function +def parallel(client, dist='b', bound=False, block=None, targets='all'): + """Turn a function into a parallel remote function. + + This method can be used for map: + + >>> @parallel(client,block=True) + def func(a) + """ + def parallel_function(f): + return ParallelFunction(client, f, dist, bound, block, targets) + return parallel_function + #-------------------------------------------------------------------------- # Classes #-------------------------------------------------------------------------- @@ -133,6 +146,103 @@ class RemoteFunction(object): block=self.block, targets=self.targets, bound=self.bound) +class ParallelFunction(RemoteFunction): + """Class for mapping a function to sequences.""" + def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'): + super(ParallelFunction, self).__init__(client,f,bound,block,targets) + mapClass = Map.dists[dist] + self.mapObject = mapClass() + + def __call__(self, *sequences): + len_0 = len(sequences[0]) + for s in sequences: + if len(s)!=len_0: + raise ValueError('all sequences must have equal length') + + if self.targets is None: + # load-balanced: + engines = [None]*len_0 + else: + # multiplexed: + engines = self.client._build_targets(self.targets)[-1] + + nparts = len(engines) + msg_ids = [] + for index, engineid in enumerate(engines): + args = [] + for seq in sequences: + args.append(self.mapObject.getPartition(seq, index, nparts)) + mid = self.client.apply(self.func, args=args, block=False, + bound=self.bound, + targets=engineid) + msg_ids.append(mid) + + if self.block: + dg = PendingMapResult(self.client, msg_ids, self.mapObject) + dg.wait() + return dg.result + else: + return dg + + +class PendingResult(object): + """Class for representing results of non-blocking calls.""" + def __init__(self, client, msg_ids): + self.client = client + self.msg_ids = msg_ids + self._result = None + self.done = False + + def __repr__(self): + if self.done: + return "<%s: finished>"%(self.__class__.__name__) + else: + return "<%s: %r>"%(self.__class__.__name__,self.msg_ids) + + @property + def result(self): + if self._result is not None: + return self._result + if not self.done: + self.wait(0) + if self.done: + results = map(self.client.results.get, self.msg_ids) + results = error.collect_exceptions(results, 'get_result') + self._result = self.reconstruct_result(results) + return self._result + else: + raise error.ResultNotCompleted + + def reconstruct_result(self, res): + """ + Override me in subclasses for turning a list of results + into the expected form. + """ + if len(res) == 1: + return res[0] + else: + return res + + def wait(self, timout=-1): + self.done = self.client.barrier(self.msg_ids) + return self.done + +class PendingMapResult(PendingResult): + """Class for representing results of non-blocking gathers. + + This will properly reconstruct the gather. + """ + + def __init__(self, client, msg_ids, mapObject): + self.mapObject = mapObject + PendingResult.__init__(self, client, msg_ids) + + def reconstruct_result(self, res): + """Perform the gather on the actual results.""" + return self.mapObject.joinPartitions(res) + + + class AbortedTask(object): """A basic wrapper object describing an aborted task.""" def __init__(self, msg_id): @@ -498,6 +608,17 @@ class Client(object): # Begin public methods #-------------------------------------------------------------------------- + @property + def remote(self): + """property for convenient RemoteFunction generation. + + >>> @client.remote + ... def f(): + import os + print (os.getpid()) + """ + return remote(self, block=self.block) + def spin(self): """Flush any registration notifications and execution results waiting in the ZMQ queue. @@ -784,7 +905,7 @@ class Client(object): self.barrier(msg_id) return self._maybe_raise(self.results[msg_id]) else: - return msg_id + return PendingResult(self, [msg_id]) def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None, after=None, follow=None): @@ -814,10 +935,7 @@ class Client(object): if block: self.barrier(msg_ids) else: - if len(msg_ids) == 1: - return msg_ids[0] - else: - return msg_ids + return PendingResult(self, msg_ids) if len(msg_ids) == 1: return self._maybe_raise(self.results[msg_ids[0]]) else: @@ -826,12 +944,17 @@ class Client(object): result[target] = self.results[mid] return error.collect_exceptions(result, f.__name__) + @defaultblock + def map(self, f, sequences, targets=None, block=None, bound=False): + pf = ParallelFunction(self,f,block=block,bound=bound,targets=targets) + return pf(*sequences) + #-------------------------------------------------------------------------- # Data movement #-------------------------------------------------------------------------- @defaultblock - def push(self, ns, targets=None, block=None): + def push(self, ns, targets='all', block=None): """Push the contents of `ns` into the namespace on `target`""" if not isinstance(ns, dict): raise TypeError("Must be a dict, not %s"%type(ns)) @@ -839,7 +962,7 @@ class Client(object): return result @defaultblock - def pull(self, keys, targets=None, block=True): + def pull(self, keys, targets='all', block=True): """Pull objects from `target`'s namespace by `keys`""" if isinstance(keys, str): pass @@ -850,6 +973,48 @@ class Client(object): result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True) return result + @defaultblock + def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None): + """ + Partition a Python sequence and send the partitions to a set of engines. + """ + targets = self._build_targets(targets)[-1] + mapObject = Map.dists[dist]() + nparts = len(targets) + msg_ids = [] + for index, engineid in enumerate(targets): + partition = mapObject.getPartition(seq, index, nparts) + if flatten and len(partition) == 1: + mid = self.push({key: partition[0]}, targets=engineid, block=False) + else: + mid = self.push({key: partition}, targets=engineid, block=False) + msg_ids.append(mid) + r = PendingResult(self, msg_ids) + if block: + r.wait() + return + else: + return r + + @defaultblock + def gather(self, key, dist='b', targets='all', block=True): + """ + Gather a partitioned sequence on a set of engines as a single local seq. + """ + + targets = self._build_targets(targets)[-1] + mapObject = Map.dists[dist]() + msg_ids = [] + for index, engineid in enumerate(targets): + msg_ids.append(self.pull(key, targets=engineid,block=False)) + + r = PendingMapResult(self, msg_ids, mapObject) + if block: + r.wait() + return r.result + else: + return r + #-------------------------------------------------------------------------- # Query methods #-------------------------------------------------------------------------- @@ -985,4 +1150,16 @@ class AsynClient(Client): for stream in (self.queue_stream, self.notifier_stream, self.task_stream, self.control_stream): stream.flush() - + +__all__ = [ 'Client', + 'depend', + 'require', + 'remote', + 'parallel', + 'RemoteFunction', + 'ParallelFunction', + 'DirectView', + 'LoadBalancedView', + 'PendingResult', + 'PendingMapResult' + ] diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py index 9f1a735..14ac61e 100644 --- a/IPython/zmq/parallel/error.py +++ b/IPython/zmq/parallel/error.py @@ -247,11 +247,15 @@ class CompositeError(KernelError): et,ev,tb = sys.exc_info() -def collect_exceptions(rdict, method): +def collect_exceptions(rdict_or_list, method): """check a result dict for errors, and raise CompositeError if any exist. Passthrough otherwise.""" elist = [] - for r in rdict.values(): + if isinstance(rdict_or_list, dict): + rlist = rdict_or_list.values() + else: + rlist = rdict_or_list + for r in rlist: if isinstance(r, RemoteError): en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info # Sometimes we could have CompositeError in our list. Just take @@ -264,7 +268,7 @@ def collect_exceptions(rdict, method): else: elist.append((en, ev, etb, ei)) if len(elist)==0: - return rdict + return rdict_or_list else: msg = "one or more exceptions from call to method: %s" % (method) # This silliness is needed so the debugger has access to the exception diff --git a/IPython/zmq/parallel/map.py b/IPython/zmq/parallel/map.py new file mode 100644 index 0000000..c2c7b2f --- /dev/null +++ b/IPython/zmq/parallel/map.py @@ -0,0 +1,158 @@ +# encoding: utf-8 + +"""Classes used in scattering and gathering sequences. + +Scattering consists of partitioning a sequence and sending the various +pieces to individual nodes in a cluster. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import types + +from IPython.utils.data import flatten as utils_flatten + +#------------------------------------------------------------------------------- +# Figure out which array packages are present and their array types +#------------------------------------------------------------------------------- + +arrayModules = [] +try: + import Numeric +except ImportError: + pass +else: + arrayModules.append({'module':Numeric, 'type':Numeric.arraytype}) +try: + import numpy +except ImportError: + pass +else: + arrayModules.append({'module':numpy, 'type':numpy.ndarray}) +try: + import numarray +except ImportError: + pass +else: + arrayModules.append({'module':numarray, + 'type':numarray.numarraycore.NumArray}) + +class Map: + """A class for partitioning a sequence using a map.""" + + def getPartition(self, seq, p, q): + """Returns the pth partition of q partitions of seq.""" + + # Test for error conditions here + if p<0 or p>=q: + print "No partition exists." + return + + remainder = len(seq)%q + basesize = len(seq)/q + hi = [] + lo = [] + for n in range(q): + if n < remainder: + lo.append(n * (basesize + 1)) + hi.append(lo[-1] + basesize + 1) + else: + lo.append(n*basesize + remainder) + hi.append(lo[-1] + basesize) + + + result = seq[lo[p]:hi[p]] + return result + + def joinPartitions(self, listOfPartitions): + return self.concatenate(listOfPartitions) + + def concatenate(self, listOfPartitions): + testObject = listOfPartitions[0] + # First see if we have a known array type + for m in arrayModules: + #print m + if isinstance(testObject, m['type']): + return m['module'].concatenate(listOfPartitions) + # Next try for Python sequence types + if isinstance(testObject, (types.ListType, types.TupleType)): + return utils_flatten(listOfPartitions) + # If we have scalars, just return listOfPartitions + return listOfPartitions + +class RoundRobinMap(Map): + """Partitions a sequence in a roun robin fashion. + + This currently does not work! + """ + + def getPartition(self, seq, p, q): + # if not isinstance(seq,(list,tuple)): + # raise NotImplementedError("cannot RR partition type %s"%type(seq)) + return seq[p:len(seq):q] + #result = [] + #for i in range(p,len(seq),q): + # result.append(seq[i]) + #return result + + def joinPartitions(self, listOfPartitions): + testObject = listOfPartitions[0] + # First see if we have a known array type + for m in arrayModules: + #print m + if isinstance(testObject, m['type']): + return self.flatten_array(m['type'], listOfPartitions) + if isinstance(testObject, (types.ListType, types.TupleType)): + return self.flatten_list(listOfPartitions) + return listOfPartitions + + def flatten_array(self, klass, listOfPartitions): + test = listOfPartitions[0] + shape = list(test.shape) + shape[0] = sum([ p.shape[0] for p in listOfPartitions]) + A = klass(shape) + N = shape[0] + q = len(listOfPartitions) + for p,part in enumerate(listOfPartitions): + A[p:N:q] = part + return A + + def flatten_list(self, listOfPartitions): + flat = [] + for i in range(len(listOfPartitions[0])): + flat.extend([ part[i] for part in listOfPartitions if len(part) > i ]) + return flat + #lengths = [len(x) for x in listOfPartitions] + #maxPartitionLength = len(listOfPartitions[0]) + #numberOfPartitions = len(listOfPartitions) + #concat = self.concatenate(listOfPartitions) + #totalLength = len(concat) + #result = [] + #for i in range(maxPartitionLength): + # result.append(concat[i:totalLength:maxPartitionLength]) + # return self.concatenate(listOfPartitions) + +def mappable(obj): + """return whether an object is mappable or not.""" + if isinstance(obj, (tuple,list)): + return True + for m in arrayModules: + if isinstance(obj,m['type']): + return True + return False + +dists = {'b':Map,'r':RoundRobinMap} + + + diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index 081c14e..382991c 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -228,6 +228,27 @@ class DirectView(View): block = block if block is not None else self.block return self.client.pull(key_s, block=block, targets=self.targets) + def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None): + """ + Partition a Python sequence and send the partitions to a set of engines. + """ + block = block if block is not None else self.block + if targets is None: + targets = self.targets + + return self.client.scatter(key, seq, dist=dist, flatten=flatten, + targets=targets, block=block) + + def gather(self, key, dist='b', targets=None, block=True): + """ + Gather a partitioned sequence on a set of engines as a single local seq. + """ + block = block if block is not None else self.block + if targets is None: + targets = self.targets + + return self.client.gather(key, dist=dist, targets=targets, block=block) + def __getitem__(self, key): return self.get(key)