diff --git a/IPython/parallel/client/map.py b/IPython/parallel/client/map.py
index 1fc6d65..67b5b29 100644
--- a/IPython/parallel/client/map.py
+++ b/IPython/parallel/client/map.py
@@ -4,56 +4,33 @@
 
 Scattering consists of partitioning a sequence and sending the various
 pieces to individual nodes in a cluster.
-
-
-Authors:
-
-* Brian Granger
-* MinRK
-
 """
 
-#-------------------------------------------------------------------------------
-#  Copyright (C) 2008-2011  The IPython Development Team
-#
-#  Distributed under the terms of the BSD License.  The full license is in
-#  the file COPYING, distributed as part of this software.
-#-------------------------------------------------------------------------------
-
-#-------------------------------------------------------------------------------
-# Imports
-#-------------------------------------------------------------------------------
+# Copyright (c) IPython Development Team.
+# Distributed under the terms of the Modified BSD License.
 
 from __future__ import division
 
+import sys
 from itertools import islice
 
 from IPython.utils.data import flatten as utils_flatten
 
-#-------------------------------------------------------------------------------
-# Figure out which array packages are present and their array types
-#-------------------------------------------------------------------------------
-
-arrayModules = []
-try:
-    import Numeric
-except ImportError:
-    pass
-else:
-    arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
-try:
-    import numpy
-except ImportError:
-    pass
-else:
-    arrayModules.append({'module':numpy, 'type':numpy.ndarray})
-try:
-    import numarray
-except ImportError:
-    pass
-else:
-    arrayModules.append({'module':numarray,
-                         'type':numarray.numarraycore.NumArray})
+
+numpy = None
+
+def is_array(obj):
+    """Is an object a numpy array?
+
+    Avoids importing numpy until it is requested
+    """
+    global numpy
+    if 'numpy' not in sys.modules:
+        return False
+
+    if numpy is None:
+        import numpy
+    return isinstance(obj, numpy.ndarray)
 
 class Map(object):
     """A class for partitioning a sequence using a map."""
@@ -90,14 +67,12 @@ class Map(object):
 
     def joinPartitions(self, listOfPartitions):
         return self.concatenate(listOfPartitions)
-    
+
     def concatenate(self, listOfPartitions):
         testObject = listOfPartitions[0]
         # First see if we have a known array type
-        for m in arrayModules:
-            #print m
-            if isinstance(testObject, m['type']):
-                return m['module'].concatenate(listOfPartitions)
+        if is_array(testObject):
+            return numpy.concatenate(listOfPartitions)
         # Next try for Python sequence types
         if isinstance(testObject, (list, tuple)):
             return utils_flatten(listOfPartitions)
@@ -117,19 +92,18 @@ class RoundRobinMap(Map):
     def joinPartitions(self, listOfPartitions):
         testObject = listOfPartitions[0]
         # First see if we have a known array type
-        for m in arrayModules:
-            #print m
-            if isinstance(testObject, m['type']):
-                return self.flatten_array(m['type'], listOfPartitions)
+        if is_array(testObject):
+            return self.flatten_array(listOfPartitions)
         if isinstance(testObject, (list, tuple)):
             return self.flatten_list(listOfPartitions)
         return listOfPartitions
 
-    def flatten_array(self, klass, listOfPartitions):
+    def flatten_array(self, listOfPartitions):
         test = listOfPartitions[0]
         shape = list(test.shape)
         shape[0] = sum([ p.shape[0] for p in listOfPartitions])
-        A = klass(shape)
+        # keep the partitions' dtype: numpy.ndarray(shape) would always be float64
+        A = numpy.empty(shape, dtype=test.dtype)
         N = shape[0]
         q = len(listOfPartitions)
         for p,part in enumerate(listOfPartitions):
@@ -141,23 +115,13 @@ class RoundRobinMap(Map):
         for i in range(len(listOfPartitions[0])):
             flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
         return flat
-        #lengths = [len(x) for x in listOfPartitions]
-        #maxPartitionLength = len(listOfPartitions[0])
-        #numberOfPartitions = len(listOfPartitions)
-        #concat = self.concatenate(listOfPartitions)
-        #totalLength = len(concat)
-        #result = []
-        #for i in range(maxPartitionLength):
-        #    result.append(concat[i:totalLength:maxPartitionLength])
-        # return self.concatenate(listOfPartitions)
 
 def mappable(obj):
     """return whether an object is mappable or not."""
     if isinstance(obj, (tuple,list)):
         return True
-    for m in arrayModules:
-        if isinstance(obj,m['type']):
-            return True
+    if is_array(obj):
+        return True
     return False
 
 dists = {'b':Map,'r':RoundRobinMap}