##// END OF EJS Templates
avoid unnecessary imports of numpy...
MinRK -
Show More
@@ -4,56 +4,33 b''
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
14 """
7 """
15
8
16 #-------------------------------------------------------------------------------
9 # Copyright (c) IPython Development Team.
17 # Copyright (C) 2008-2011 The IPython Development Team
10 # Distributed under the terms of the Modified BSD License.
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
22
23 #-------------------------------------------------------------------------------
24 # Imports
25 #-------------------------------------------------------------------------------
26
11
27 from __future__ import division
12 from __future__ import division
28
13
14 import sys
29 from itertools import islice
15 from itertools import islice
30
16
31 from IPython.utils.data import flatten as utils_flatten
17 from IPython.utils.data import flatten as utils_flatten
32
18
33 #-------------------------------------------------------------------------------
19
34 # Figure out which array packages are present and their array types
20 numpy = None
35 #-------------------------------------------------------------------------------
21
36
22 def is_array(obj):
37 arrayModules = []
23 """Is an object a numpy array?
38 try:
24
39 import Numeric
25 Avoids importing numpy until it is requested
40 except ImportError:
26 """
41 pass
27 global numpy
42 else:
28 if 'numpy' not in sys.modules:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
29 return False
44 try:
30
45 import numpy
31 if numpy is None:
46 except ImportError:
32 import numpy
47 pass
33 return isinstance(obj, numpy.ndarray)
48 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 import numarray
52 except ImportError:
53 pass
54 else:
55 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
57
34
58 class Map(object):
35 class Map(object):
59 """A class for partitioning a sequence using a map."""
36 """A class for partitioning a sequence using a map."""
@@ -90,14 +67,12 b' class Map(object):'
90
67
91 def joinPartitions(self, listOfPartitions):
68 def joinPartitions(self, listOfPartitions):
92 return self.concatenate(listOfPartitions)
69 return self.concatenate(listOfPartitions)
93
70
94 def concatenate(self, listOfPartitions):
71 def concatenate(self, listOfPartitions):
95 testObject = listOfPartitions[0]
72 testObject = listOfPartitions[0]
96 # First see if we have a known array type
73 # First see if we have a known array type
97 for m in arrayModules:
74 if is_array(testObject):
98 #print m
75 return numpy.concatenate(listOfPartitions)
99 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
101 # Next try for Python sequence types
76 # Next try for Python sequence types
102 if isinstance(testObject, (list, tuple)):
77 if isinstance(testObject, (list, tuple)):
103 return utils_flatten(listOfPartitions)
78 return utils_flatten(listOfPartitions)
@@ -117,19 +92,17 b' class RoundRobinMap(Map):'
117 def joinPartitions(self, listOfPartitions):
92 def joinPartitions(self, listOfPartitions):
118 testObject = listOfPartitions[0]
93 testObject = listOfPartitions[0]
119 # First see if we have a known array type
94 # First see if we have a known array type
120 for m in arrayModules:
95 if is_array(testObject):
121 #print m
96 return self.flatten_array(listOfPartitions)
122 if isinstance(testObject, m['type']):
123 return self.flatten_array(m['type'], listOfPartitions)
124 if isinstance(testObject, (list, tuple)):
97 if isinstance(testObject, (list, tuple)):
125 return self.flatten_list(listOfPartitions)
98 return self.flatten_list(listOfPartitions)
126 return listOfPartitions
99 return listOfPartitions
127
100
128 def flatten_array(self, klass, listOfPartitions):
101 def flatten_array(self, listOfPartitions):
129 test = listOfPartitions[0]
102 test = listOfPartitions[0]
130 shape = list(test.shape)
103 shape = list(test.shape)
131 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
104 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 A = klass(shape)
105 A = numpy.ndarray(shape)
133 N = shape[0]
106 N = shape[0]
134 q = len(listOfPartitions)
107 q = len(listOfPartitions)
135 for p,part in enumerate(listOfPartitions):
108 for p,part in enumerate(listOfPartitions):
@@ -141,23 +114,13 b' class RoundRobinMap(Map):'
141 for i in range(len(listOfPartitions[0])):
114 for i in range(len(listOfPartitions[0])):
142 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
115 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 return flat
116 return flat
144 #lengths = [len(x) for x in listOfPartitions]
145 #maxPartitionLength = len(listOfPartitions[0])
146 #numberOfPartitions = len(listOfPartitions)
147 #concat = self.concatenate(listOfPartitions)
148 #totalLength = len(concat)
149 #result = []
150 #for i in range(maxPartitionLength):
151 # result.append(concat[i:totalLength:maxPartitionLength])
152 # return self.concatenate(listOfPartitions)
153
117
154 def mappable(obj):
118 def mappable(obj):
155 """return whether an object is mappable or not."""
119 """return whether an object is mappable or not."""
156 if isinstance(obj, (tuple,list)):
120 if isinstance(obj, (tuple,list)):
157 return True
121 return True
158 for m in arrayModules:
122 if is_array(obj):
159 if isinstance(obj,m['type']):
123 return True
160 return True
161 return False
124 return False
162
125
163 dists = {'b':Map,'r':RoundRobinMap}
126 dists = {'b':Map,'r':RoundRobinMap}
General Comments 0
You need to be logged in to leave comments. Login now