Show More
map.py
129 lines
| 3.7 KiB
| text/x-python
|
PythonLexer
MinRK
|
# encoding: utf-8
"""Classes used in scattering and gathering sequences.

Scattering consists of partitioning a sequence and sending the various
pieces to individual nodes in a cluster.
"""
MinRK
|
r17053 | # Copyright (c) IPython Development Team. | ||
# Distributed under the terms of the Modified BSD License. | ||||
MinRK
|
r3587 | |||
MinRK
|
r4155 | from __future__ import division | ||
MinRK
|
r17053 | import sys | ||
MinRK
|
r5560 | from itertools import islice | ||
MinRK
|
r3587 | |||
from IPython.utils.data import flatten as utils_flatten | ||||
MinRK
|
r17053 | |||
# Lazily bound handle to the numpy module. It stays None until is_array()
# runs in a process where numpy has already been imported by someone else.
numpy = None


def is_array(obj):
    """Return True if ``obj`` is a numpy ndarray, False otherwise.

    numpy is never imported on behalf of the caller: when it is absent
    from ``sys.modules`` no ndarray can exist yet, so the answer is False
    without touching numpy. Once numpy is known to be loaded, the
    module-level ``numpy`` name is bound on first use.
    """
    global numpy
    if 'numpy' in sys.modules:
        if numpy is None:
            # Already loaded elsewhere, so this only binds the global name.
            import numpy
        return isinstance(obj, numpy.ndarray)
    return False
MinRK
|
r3587 | |||
MinRK
|
class Map(object):
    """A class for partitioning a sequence using a map.

    A sequence is cut into ``q`` contiguous partitions whose sizes differ
    by at most one; ``joinPartitions`` glues such partitions back together.
    """

    def getPartition(self, seq, p, q, n=None):
        """Return the pth partition of q partitions of seq.

        Parameters
        ----------
        seq : sliceable sequence or iterable
            The data to partition. Objects that cannot be sliced (e.g.
            iterators) are consumed with ``itertools.islice`` instead.
        p : int
            Zero-based index of the requested partition.
        q : int
            Total number of partitions.
        n : int, optional
            Length of ``seq``. Defaults to ``len(seq)``; it must be given
            explicitly for objects that do not support ``len``.

        Returns
        -------
        ``seq[low:high]`` for sliceable input, otherwise a list holding
        the same range of items.

        Raises
        ------
        ValueError
            If ``p`` is not in the half-open range ``0 <= p < q``.
        """
        n = len(seq) if n is None else n

        # Test for error conditions here
        if p < 0 or p >= q:
            raise ValueError("must have 0 <= p < q, but have p=%s,q=%s" % (p, q))

        basesize, remainder = divmod(n, q)

        # The first `remainder` partitions each carry one extra item, so
        # partition p starts after p full base-size chunks plus however
        # many of those extra items precede it.
        low = p * basesize + min(p, remainder)
        high = low + basesize + (1 if p < remainder else 0)

        try:
            result = seq[low:high]
        except TypeError:
            # some objects (iterators) can't be sliced,
            # use islice:
            result = list(islice(seq, low, high))

        return result

    def joinPartitions(self, listOfPartitions):
        """Join partitions produced by ``getPartition``; see ``concatenate``."""
        return self.concatenate(listOfPartitions)

    def concatenate(self, listOfPartitions):
        """Concatenate a list of partitions end to end.

        numpy arrays are joined with ``numpy.concatenate``, lists and
        tuples are flattened by one level with ``utils_flatten``, and
        anything else (scalars) is returned unchanged. The type of the
        first partition decides which path is taken. An empty
        ``listOfPartitions`` is returned as-is.
        """
        # Nothing to inspect or join; avoid an IndexError on [0] below.
        if len(listOfPartitions) == 0:
            return listOfPartitions

        testObject = listOfPartitions[0]
        # First see if we have a known array type
        if is_array(testObject):
            # is_array() has bound the module-level `numpy` name by the
            # time it returns True.
            return numpy.concatenate(listOfPartitions)
        # Next try for Python sequence types
        if isinstance(testObject, (list, tuple)):
            return utils_flatten(listOfPartitions)
        # If we have scalars, just return listOfPartitions
        return listOfPartitions
class RoundRobinMap(Map):
    """Partitions a sequence in a round robin fashion.

    Partition ``p`` of ``q`` holds items ``p, p + q, p + 2q, ...`` of the
    sequence, and ``joinPartitions`` interleaves partitions back into the
    original order.

    This currently does not work!
    (Warning carried over from the original docstring; it does not say
    which part is broken.)
    """

    def getPartition(self, seq, p, q, n=None):
        """Return every qth item of ``seq`` starting at index ``p``.

        ``n`` is the length of ``seq`` and defaults to ``len(seq)``.
        Objects that cannot be sliced (e.g. iterators) are consumed with
        ``itertools.islice``, as in ``Map.getPartition``, and yield a list.
        """
        n = len(seq) if n is None else n
        try:
            return seq[p:n:q]
        except TypeError:
            # some objects (iterators) can't be sliced, use islice:
            return list(islice(seq, p, n, q))

    def joinPartitions(self, listOfPartitions):
        """Interleave round-robin partitions back into a single sequence.

        The type of the first partition picks the strategy: numpy arrays
        go through ``flatten_array``, lists and tuples through
        ``flatten_list``, and anything else is returned unchanged.
        """
        testObject = listOfPartitions[0]
        # First see if we have a known array type
        if is_array(testObject):
            return self.flatten_array(listOfPartitions)
        if isinstance(testObject, (list, tuple)):
            return self.flatten_list(listOfPartitions)
        return listOfPartitions

    def flatten_array(self, listOfPartitions):
        """Interleave array partitions along their first axis.

        The result has the trailing shape of the first partition, a first
        axis as long as all partitions combined, and the promoted dtype of
        the partitions (so integer or string input is no longer coerced
        to float64).

        Relies on the module-level ``numpy`` name having been bound by
        ``is_array``, which ``joinPartitions`` calls first.
        """
        first = listOfPartitions[0]
        shape = list(first.shape)
        shape[0] = sum(part.shape[0] for part in listOfPartitions)

        # Preserve the input dtype instead of ndarray's float64 default.
        dtype = first.dtype
        for part in listOfPartitions[1:]:
            dtype = numpy.promote_types(dtype, part.dtype)

        A = numpy.empty(shape, dtype=dtype)
        N = shape[0]
        q = len(listOfPartitions)
        for p, part in enumerate(listOfPartitions):
            # Partition p owns rows p, p + q, p + 2q, ... of the result.
            A[p:N:q] = part
        return A

    def flatten_list(self, listOfPartitions):
        """Interleave list partitions item by item into one flat list.

        Only as many rounds are taken as the first partition has items,
        so the first partition is assumed to be at least as long as the
        others; shorter partitions are skipped once exhausted.
        """
        flat = []
        for i in range(len(listOfPartitions[0])):
            flat.extend([part[i] for part in listOfPartitions if len(part) > i])
        return flat
def mappable(obj):
    """Return whether an object is mappable or not.

    Tuples, lists, and numpy arrays (as judged by ``is_array``) are
    mappable; every other object is not.
    """
    return isinstance(obj, (tuple, list)) or is_array(obj)
# Registry of partitioning strategies keyed by a one-letter code
# (presumably 'b' for block and 'r' for round-robin -- confirm with callers).
dists = {
    'b': Map,
    'r': RoundRobinMap,
}