##// END OF EJS Templates
don't create lists...
MinRK -
Show More
@@ -1,171 +1,170
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
7
8
8
9 Authors:
9 Authors:
10
10
11 * Brian Granger
11 * Brian Granger
12 * MinRK
12 * MinRK
13
13
14 """
14 """
15
15
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
21 #-------------------------------------------------------------------------------
22
22
23 #-------------------------------------------------------------------------------
23 #-------------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26
26
27 from __future__ import division
27 from __future__ import division
28
28
29 import types
29 import types
30 from itertools import islice
30 from itertools import islice
31
31
32 from IPython.utils.data import flatten as utils_flatten
32 from IPython.utils.data import flatten as utils_flatten
33
33
34 #-------------------------------------------------------------------------------
34 #-------------------------------------------------------------------------------
35 # Figure out which array packages are present and their array types
35 # Figure out which array packages are present and their array types
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37
37
38 arrayModules = []
38 arrayModules = []
39 try:
39 try:
40 import Numeric
40 import Numeric
41 except ImportError:
41 except ImportError:
42 pass
42 pass
43 else:
43 else:
44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
45 try:
45 try:
46 import numpy
46 import numpy
47 except ImportError:
47 except ImportError:
48 pass
48 pass
49 else:
49 else:
50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
51 try:
51 try:
52 import numarray
52 import numarray
53 except ImportError:
53 except ImportError:
54 pass
54 pass
55 else:
55 else:
56 arrayModules.append({'module':numarray,
56 arrayModules.append({'module':numarray,
57 'type':numarray.numarraycore.NumArray})
57 'type':numarray.numarraycore.NumArray})
58
58
59 class Map:
59 class Map(object):
60 """A class for partitioning a sequence using a map."""
60 """A class for partitioning a sequence using a map."""
61
61
62 def getPartition(self, seq, p, q):
62 def getPartition(self, seq, p, q):
63 """Returns the pth partition of q partitions of seq."""
63 """Returns the pth partition of q partitions of seq."""
64
64
65 # Test for error conditions here
65 # Test for error conditions here
66 if p<0 or p>=q:
66 if p<0 or p>=q:
67 print "No partition exists."
67 print "No partition exists."
68 return
68 return
69
69
70 remainder = len(seq)%q
70 N = len(seq)
71 basesize = len(seq)//q
71 remainder = N % q
72 hi = []
72 basesize = N // q
73 lo = []
73
74 for n in range(q):
74 if p < remainder:
75 if n < remainder:
75 low = p * (basesize + 1)
76 lo.append(n * (basesize + 1))
76 high = low + basesize + 1
77 hi.append(lo[-1] + basesize + 1)
77 else:
78 else:
78 low = p * basesize + remainder
79 lo.append(n*basesize + remainder)
79 high = low + basesize
80 hi.append(lo[-1] + basesize)
81
80
82 try:
81 try:
83 result = seq[lo[p]:hi[p]]
82 result = seq[low:high]
84 except TypeError:
83 except TypeError:
85 # some objects (iterators) can't be sliced,
84 # some objects (iterators) can't be sliced,
86 # use islice:
85 # use islice:
87 result = list(islice(seq, lo[p], hi[p]))
86 result = list(islice(seq, low, high))
88
87
89 return result
88 return result
90
89
91 def joinPartitions(self, listOfPartitions):
90 def joinPartitions(self, listOfPartitions):
92 return self.concatenate(listOfPartitions)
91 return self.concatenate(listOfPartitions)
93
92
94 def concatenate(self, listOfPartitions):
93 def concatenate(self, listOfPartitions):
95 testObject = listOfPartitions[0]
94 testObject = listOfPartitions[0]
96 # First see if we have a known array type
95 # First see if we have a known array type
97 for m in arrayModules:
96 for m in arrayModules:
98 #print m
97 #print m
99 if isinstance(testObject, m['type']):
98 if isinstance(testObject, m['type']):
100 return m['module'].concatenate(listOfPartitions)
99 return m['module'].concatenate(listOfPartitions)
101 # Next try for Python sequence types
100 # Next try for Python sequence types
102 if isinstance(testObject, (types.ListType, types.TupleType)):
101 if isinstance(testObject, (types.ListType, types.TupleType)):
103 return utils_flatten(listOfPartitions)
102 return utils_flatten(listOfPartitions)
104 # If we have scalars, just return listOfPartitions
103 # If we have scalars, just return listOfPartitions
105 return listOfPartitions
104 return listOfPartitions
106
105
107 class RoundRobinMap(Map):
106 class RoundRobinMap(Map):
108 """Partitions a sequence in a roun robin fashion.
107 """Partitions a sequence in a roun robin fashion.
109
108
110 This currently does not work!
109 This currently does not work!
111 """
110 """
112
111
113 def getPartition(self, seq, p, q):
112 def getPartition(self, seq, p, q):
114 # if not isinstance(seq,(list,tuple)):
113 # if not isinstance(seq,(list,tuple)):
115 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
114 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
116 return seq[p:len(seq):q]
115 return seq[p:len(seq):q]
117 #result = []
116 #result = []
118 #for i in range(p,len(seq),q):
117 #for i in range(p,len(seq),q):
119 # result.append(seq[i])
118 # result.append(seq[i])
120 #return result
119 #return result
121
120
122 def joinPartitions(self, listOfPartitions):
121 def joinPartitions(self, listOfPartitions):
123 testObject = listOfPartitions[0]
122 testObject = listOfPartitions[0]
124 # First see if we have a known array type
123 # First see if we have a known array type
125 for m in arrayModules:
124 for m in arrayModules:
126 #print m
125 #print m
127 if isinstance(testObject, m['type']):
126 if isinstance(testObject, m['type']):
128 return self.flatten_array(m['type'], listOfPartitions)
127 return self.flatten_array(m['type'], listOfPartitions)
129 if isinstance(testObject, (types.ListType, types.TupleType)):
128 if isinstance(testObject, (types.ListType, types.TupleType)):
130 return self.flatten_list(listOfPartitions)
129 return self.flatten_list(listOfPartitions)
131 return listOfPartitions
130 return listOfPartitions
132
131
133 def flatten_array(self, klass, listOfPartitions):
132 def flatten_array(self, klass, listOfPartitions):
134 test = listOfPartitions[0]
133 test = listOfPartitions[0]
135 shape = list(test.shape)
134 shape = list(test.shape)
136 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
135 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
137 A = klass(shape)
136 A = klass(shape)
138 N = shape[0]
137 N = shape[0]
139 q = len(listOfPartitions)
138 q = len(listOfPartitions)
140 for p,part in enumerate(listOfPartitions):
139 for p,part in enumerate(listOfPartitions):
141 A[p:N:q] = part
140 A[p:N:q] = part
142 return A
141 return A
143
142
144 def flatten_list(self, listOfPartitions):
143 def flatten_list(self, listOfPartitions):
145 flat = []
144 flat = []
146 for i in range(len(listOfPartitions[0])):
145 for i in range(len(listOfPartitions[0])):
147 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
146 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
148 return flat
147 return flat
149 #lengths = [len(x) for x in listOfPartitions]
148 #lengths = [len(x) for x in listOfPartitions]
150 #maxPartitionLength = len(listOfPartitions[0])
149 #maxPartitionLength = len(listOfPartitions[0])
151 #numberOfPartitions = len(listOfPartitions)
150 #numberOfPartitions = len(listOfPartitions)
152 #concat = self.concatenate(listOfPartitions)
151 #concat = self.concatenate(listOfPartitions)
153 #totalLength = len(concat)
152 #totalLength = len(concat)
154 #result = []
153 #result = []
155 #for i in range(maxPartitionLength):
154 #for i in range(maxPartitionLength):
156 # result.append(concat[i:totalLength:maxPartitionLength])
155 # result.append(concat[i:totalLength:maxPartitionLength])
157 # return self.concatenate(listOfPartitions)
156 # return self.concatenate(listOfPartitions)
158
157
159 def mappable(obj):
158 def mappable(obj):
160 """return whether an object is mappable or not."""
159 """return whether an object is mappable or not."""
161 if isinstance(obj, (tuple,list)):
160 if isinstance(obj, (tuple,list)):
162 return True
161 return True
163 for m in arrayModules:
162 for m in arrayModules:
164 if isinstance(obj,m['type']):
163 if isinstance(obj,m['type']):
165 return True
164 return True
166 return False
165 return False
167
166
168 dists = {'b':Map,'r':RoundRobinMap}
167 dists = {'b':Map,'r':RoundRobinMap}
169
168
170
169
171
170
General Comments 0
You need to be logged in to leave comments. Login now