##// END OF EJS Templates
avoid unnecessary imports of numpy...
MinRK -
Show More
@@ -4,56 +4,33 b''
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
14 7 """
15 8
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
22
23 #-------------------------------------------------------------------------------
24 # Imports
25 #-------------------------------------------------------------------------------
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
26 11
27 12 from __future__ import division
28 13
14 import sys
29 15 from itertools import islice
30 16
31 17 from IPython.utils.data import flatten as utils_flatten
32 18
33 #-------------------------------------------------------------------------------
34 # Figure out which array packages are present and their array types
35 #-------------------------------------------------------------------------------
36 19
37 arrayModules = []
38 try:
39 import Numeric
40 except ImportError:
41 pass
42 else:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 try:
20 numpy = None
21
22 def is_array(obj):
23 """Is an object a numpy array?
24
25 Avoids importing numpy until it is requested
26 """
27 global numpy
28 if 'numpy' not in sys.modules:
29 return False
30
31 if numpy is None:
45 32 import numpy
46 except ImportError:
47 pass
48 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 import numarray
52 except ImportError:
53 pass
54 else:
55 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
33 return isinstance(obj, numpy.ndarray)
57 34
58 35 class Map(object):
59 36 """A class for partitioning a sequence using a map."""
@@ -94,10 +71,8 b' class Map(object):'
94 71 def concatenate(self, listOfPartitions):
95 72 testObject = listOfPartitions[0]
96 73 # First see if we have a known array type
97 for m in arrayModules:
98 #print m
99 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
74 if is_array(testObject):
75 return numpy.concatenate(listOfPartitions)
101 76 # Next try for Python sequence types
102 77 if isinstance(testObject, (list, tuple)):
103 78 return utils_flatten(listOfPartitions)
@@ -117,19 +92,17 b' class RoundRobinMap(Map):'
117 92 def joinPartitions(self, listOfPartitions):
118 93 testObject = listOfPartitions[0]
119 94 # First see if we have a known array type
120 for m in arrayModules:
121 #print m
122 if isinstance(testObject, m['type']):
123 return self.flatten_array(m['type'], listOfPartitions)
95 if is_array(testObject):
96 return self.flatten_array(listOfPartitions)
124 97 if isinstance(testObject, (list, tuple)):
125 98 return self.flatten_list(listOfPartitions)
126 99 return listOfPartitions
127 100
128 def flatten_array(self, klass, listOfPartitions):
101 def flatten_array(self, listOfPartitions):
129 102 test = listOfPartitions[0]
130 103 shape = list(test.shape)
131 104 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 A = klass(shape)
105 A = numpy.ndarray(shape)
133 106 N = shape[0]
134 107 q = len(listOfPartitions)
135 108 for p,part in enumerate(listOfPartitions):
@@ -141,22 +114,12 b' class RoundRobinMap(Map):'
141 114 for i in range(len(listOfPartitions[0])):
142 115 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 116 return flat
144 #lengths = [len(x) for x in listOfPartitions]
145 #maxPartitionLength = len(listOfPartitions[0])
146 #numberOfPartitions = len(listOfPartitions)
147 #concat = self.concatenate(listOfPartitions)
148 #totalLength = len(concat)
149 #result = []
150 #for i in range(maxPartitionLength):
151 # result.append(concat[i:totalLength:maxPartitionLength])
152 # return self.concatenate(listOfPartitions)
153 117
154 118 def mappable(obj):
155 119 """return whether an object is mappable or not."""
156 120 if isinstance(obj, (tuple,list)):
157 121 return True
158 for m in arrayModules:
159 if isinstance(obj,m['type']):
122 if is_array(obj):
160 123 return True
161 124 return False
162 125
General Comments 0
You need to be logged in to leave comments. Login now