##// END OF EJS Templates
avoid unnecessary imports of numpy...
MinRK -
Show More
@@ -1,166 +1,129 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
8
9 Authors:
10
11 * Brian Granger
12 * MinRK
13
14 """
7 """
15
8
16 #-------------------------------------------------------------------------------
9 # Copyright (c) IPython Development Team.
17 # Copyright (C) 2008-2011 The IPython Development Team
10 # Distributed under the terms of the Modified BSD License.
18 #
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
22
23 #-------------------------------------------------------------------------------
24 # Imports
25 #-------------------------------------------------------------------------------
26
11
27 from __future__ import division
12 from __future__ import division
28
13
14 import sys
29 from itertools import islice
15 from itertools import islice
30
16
31 from IPython.utils.data import flatten as utils_flatten
17 from IPython.utils.data import flatten as utils_flatten
32
18
33 #-------------------------------------------------------------------------------
19
34 # Figure out which array packages are present and their array types
20 numpy = None
35 #-------------------------------------------------------------------------------
21
36
22 def is_array(obj):
37 arrayModules = []
23 """Is an object a numpy array?
38 try:
24
39 import Numeric
25 Avoids importing numpy until it is requested
40 except ImportError:
26 """
41 pass
27 global numpy
42 else:
28 if 'numpy' not in sys.modules:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
29 return False
44 try:
30
45 import numpy
31 if numpy is None:
46 except ImportError:
32 import numpy
47 pass
33 return isinstance(obj, numpy.ndarray)
48 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 import numarray
52 except ImportError:
53 pass
54 else:
55 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
57
34
58 class Map(object):
35 class Map(object):
59 """A class for partitioning a sequence using a map."""
36 """A class for partitioning a sequence using a map."""
60
37
61 def getPartition(self, seq, p, q, n=None):
38 def getPartition(self, seq, p, q, n=None):
62 """Returns the pth partition of q partitions of seq.
39 """Returns the pth partition of q partitions of seq.
63
40
64 The length can be specified as `n`,
41 The length can be specified as `n`,
65 otherwise it is the value of `len(seq)`
42 otherwise it is the value of `len(seq)`
66 """
43 """
67 n = len(seq) if n is None else n
44 n = len(seq) if n is None else n
68 # Test for error conditions here
45 # Test for error conditions here
69 if p<0 or p>=q:
46 if p<0 or p>=q:
70 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
47 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
71
48
72 remainder = n % q
49 remainder = n % q
73 basesize = n // q
50 basesize = n // q
74
51
75 if p < remainder:
52 if p < remainder:
76 low = p * (basesize + 1)
53 low = p * (basesize + 1)
77 high = low + basesize + 1
54 high = low + basesize + 1
78 else:
55 else:
79 low = p * basesize + remainder
56 low = p * basesize + remainder
80 high = low + basesize
57 high = low + basesize
81
58
82 try:
59 try:
83 result = seq[low:high]
60 result = seq[low:high]
84 except TypeError:
61 except TypeError:
85 # some objects (iterators) can't be sliced,
62 # some objects (iterators) can't be sliced,
86 # use islice:
63 # use islice:
87 result = list(islice(seq, low, high))
64 result = list(islice(seq, low, high))
88
65
89 return result
66 return result
90
67
91 def joinPartitions(self, listOfPartitions):
68 def joinPartitions(self, listOfPartitions):
92 return self.concatenate(listOfPartitions)
69 return self.concatenate(listOfPartitions)
93
70
94 def concatenate(self, listOfPartitions):
71 def concatenate(self, listOfPartitions):
95 testObject = listOfPartitions[0]
72 testObject = listOfPartitions[0]
96 # First see if we have a known array type
73 # First see if we have a known array type
97 for m in arrayModules:
74 if is_array(testObject):
98 #print m
75 return numpy.concatenate(listOfPartitions)
99 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
101 # Next try for Python sequence types
76 # Next try for Python sequence types
102 if isinstance(testObject, (list, tuple)):
77 if isinstance(testObject, (list, tuple)):
103 return utils_flatten(listOfPartitions)
78 return utils_flatten(listOfPartitions)
104 # If we have scalars, just return listOfPartitions
79 # If we have scalars, just return listOfPartitions
105 return listOfPartitions
80 return listOfPartitions
106
81
107 class RoundRobinMap(Map):
82 class RoundRobinMap(Map):
108 """Partitions a sequence in a round robin fashion.
83 """Partitions a sequence in a round robin fashion.
109
84
110 This currently does not work!
85 This currently does not work!
111 """
86 """
112
87
113 def getPartition(self, seq, p, q, n=None):
88 def getPartition(self, seq, p, q, n=None):
114 n = len(seq) if n is None else n
89 n = len(seq) if n is None else n
115 return seq[p:n:q]
90 return seq[p:n:q]
116
91
117 def joinPartitions(self, listOfPartitions):
92 def joinPartitions(self, listOfPartitions):
118 testObject = listOfPartitions[0]
93 testObject = listOfPartitions[0]
119 # First see if we have a known array type
94 # First see if we have a known array type
120 for m in arrayModules:
95 if is_array(testObject):
121 #print m
96 return self.flatten_array(listOfPartitions)
122 if isinstance(testObject, m['type']):
123 return self.flatten_array(m['type'], listOfPartitions)
124 if isinstance(testObject, (list, tuple)):
97 if isinstance(testObject, (list, tuple)):
125 return self.flatten_list(listOfPartitions)
98 return self.flatten_list(listOfPartitions)
126 return listOfPartitions
99 return listOfPartitions
127
100
128 def flatten_array(self, klass, listOfPartitions):
101 def flatten_array(self, listOfPartitions):
129 test = listOfPartitions[0]
102 test = listOfPartitions[0]
130 shape = list(test.shape)
103 shape = list(test.shape)
131 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
104 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 A = klass(shape)
105 A = numpy.ndarray(shape)
133 N = shape[0]
106 N = shape[0]
134 q = len(listOfPartitions)
107 q = len(listOfPartitions)
135 for p,part in enumerate(listOfPartitions):
108 for p,part in enumerate(listOfPartitions):
136 A[p:N:q] = part
109 A[p:N:q] = part
137 return A
110 return A
138
111
139 def flatten_list(self, listOfPartitions):
112 def flatten_list(self, listOfPartitions):
140 flat = []
113 flat = []
141 for i in range(len(listOfPartitions[0])):
114 for i in range(len(listOfPartitions[0])):
142 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
115 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 return flat
116 return flat
144 #lengths = [len(x) for x in listOfPartitions]
145 #maxPartitionLength = len(listOfPartitions[0])
146 #numberOfPartitions = len(listOfPartitions)
147 #concat = self.concatenate(listOfPartitions)
148 #totalLength = len(concat)
149 #result = []
150 #for i in range(maxPartitionLength):
151 # result.append(concat[i:totalLength:maxPartitionLength])
152 # return self.concatenate(listOfPartitions)
153
117
154 def mappable(obj):
118 def mappable(obj):
155 """return whether an object is mappable or not."""
119 """return whether an object is mappable or not."""
156 if isinstance(obj, (tuple,list)):
120 if isinstance(obj, (tuple,list)):
157 return True
121 return True
158 for m in arrayModules:
122 if is_array(obj):
159 if isinstance(obj,m['type']):
123 return True
160 return True
161 return False
124 return False
162
125
163 dists = {'b':Map,'r':RoundRobinMap}
126 dists = {'b':Map,'r':RoundRobinMap}
164
127
165
128
166
129
General Comments 0
You need to be logged in to leave comments. Login now