##// END OF EJS Templates
allow map objects to partition specified lengths
MinRK -
Show More
@@ -1,170 +1,167
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
7
8
8
9 Authors:
9 Authors:
10
10
11 * Brian Granger
11 * Brian Granger
12 * MinRK
12 * MinRK
13
13
14 """
14 """
15
15
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
21 #-------------------------------------------------------------------------------
22
22
23 #-------------------------------------------------------------------------------
23 #-------------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26
26
27 from __future__ import division
27 from __future__ import division
28
28
29 import types
29 import types
30 from itertools import islice
30 from itertools import islice
31
31
32 from IPython.utils.data import flatten as utils_flatten
32 from IPython.utils.data import flatten as utils_flatten
33
33
34 #-------------------------------------------------------------------------------
34 #-------------------------------------------------------------------------------
35 # Figure out which array packages are present and their array types
35 # Figure out which array packages are present and their array types
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37
37
# Registry of the array packages that can be imported in this environment.
# Each entry is a dict holding the package itself under 'module' and that
# package's array class under 'type'.  A package that fails to import is
# simply left out of the registry.
arrayModules = []

try:
    import Numeric
except ImportError:
    pass
else:
    arrayModules.append(dict(module=Numeric, type=Numeric.arraytype))

try:
    import numpy
except ImportError:
    pass
else:
    arrayModules.append(dict(module=numpy, type=numpy.ndarray))

try:
    import numarray
except ImportError:
    pass
else:
    arrayModules.append(
        dict(module=numarray, type=numarray.numarraycore.NumArray)
    )
58
58
class Map(object):
    """Partition a sequence into contiguous blocks and join them back."""

    def getPartition(self, seq, p, q, n=None):
        """Return the `p`-th of `q` contiguous partitions of `seq`.

        The first ``n % q`` partitions each get one extra element, so
        partition sizes never differ by more than one.

        Parameters
        ----------
        seq : sequence or iterable
            The data to partition.  Objects that cannot be sliced (such as
            iterators) are handled with `itertools.islice`.
        p : int
            Index of the requested partition; must satisfy ``0 <= p < q``.
        q : int
            Total number of partitions.
        n : int, optional
            Length to partition over.  Defaults to ``len(seq)``; it has to
            be passed explicitly when `seq` does not support `len`.

        Returns
        -------
        ``seq[low:high]`` when `seq` supports slicing, otherwise a list of
        the corresponding elements.

        Raises
        ------
        ValueError
            If `p` is outside ``0 <= p < q``.
        """
        n = len(seq) if n is None else n
        # Test for error conditions here
        if p < 0 or p >= q:
            raise ValueError(
                "must have 0 <= p < q, but have p=%s,q=%s" % (p, q)
            )

        remainder = n % q
        basesize = n // q

        if p < remainder:
            # One of the leading partitions that absorbs an extra element.
            low = p * (basesize + 1)
            high = low + basesize + 1
        else:
            # Skip past the `remainder` enlarged partitions.
            low = p * basesize + remainder
            high = low + basesize

        try:
            result = seq[low:high]
        except TypeError:
            # some objects (iterators) can't be sliced,
            # use islice (this consumes the iterator):
            result = list(islice(seq, low, high))

        return result

    def joinPartitions(self, listOfPartitions):
        """Reassemble partitions; delegates to `concatenate`."""
        return self.concatenate(listOfPartitions)

    def concatenate(self, listOfPartitions):
        """Join `listOfPartitions` according to the type of its first item.

        Known array types are joined with their module's ``concatenate``,
        lists and tuples are passed to `utils_flatten`, and anything else
        (e.g. scalars) is returned unchanged.  An empty input is returned
        as is.
        """
        if len(listOfPartitions) == 0:
            # Nothing to inspect; avoid an IndexError on the lookup below.
            return listOfPartitions
        testObject = listOfPartitions[0]
        # First see if we have a known array type
        for m in arrayModules:
            if isinstance(testObject, m['type']):
                return m['module'].concatenate(listOfPartitions)
        # Next try for Python sequence types
        if isinstance(testObject, (list, tuple)):
            return utils_flatten(listOfPartitions)
        # If we have scalars, just return listOfPartitions
        return listOfPartitions
105
107
class RoundRobinMap(Map):
    """Partitions a sequence in a round robin fashion.

    Partition `p` of `q` holds elements ``p, p+q, p+2q, ...`` of the
    sequence.

    NOTE(review): the original docstring states "This currently does not
    work!" -- confirm whether that warning still applies.
    """

    def getPartition(self, seq, p, q, n=None):
        """Return every `q`-th element of `seq`, starting at index `p`.

        `n` is the stop index and defaults to ``len(seq)``; it has to be
        passed explicitly when `seq` does not support `len`.  Unlike
        `Map.getPartition`, `p` and `q` are not range-checked here.
        """
        n = len(seq) if n is None else n
        try:
            return seq[p:n:q]
        except TypeError:
            # some objects (iterators) can't be sliced,
            # use islice (this consumes the iterator):
            return list(islice(seq, p, n, q))

    def joinPartitions(self, listOfPartitions):
        """Interleave partitions back into their original order.

        Dispatches on the type of the first partition: known array types
        go to `flatten_array`, lists and tuples go to `flatten_list`, and
        anything else is returned unchanged.  An empty input is returned
        as is.
        """
        if len(listOfPartitions) == 0:
            # Nothing to inspect; avoid an IndexError on the lookup below.
            return listOfPartitions
        testObject = listOfPartitions[0]
        # First see if we have a known array type
        for m in arrayModules:
            if isinstance(testObject, m['type']):
                return self.flatten_array(m['type'], listOfPartitions)
        if isinstance(testObject, (list, tuple)):
            return self.flatten_list(listOfPartitions)
        return listOfPartitions

    def flatten_array(self, klass, listOfPartitions):
        """Interleave array partitions along axis 0 into a new `klass` array.

        The result keeps the trailing dimensions of the first partition and
        its first dimension is the sum of the partitions' first dimensions.
        """
        test = listOfPartitions[0]
        shape = list(test.shape)
        shape[0] = sum([part.shape[0] for part in listOfPartitions])
        # NOTE(review): `klass(shape)` is called without a dtype, so the
        # result may not keep the dtype of the partitions -- confirm.
        A = klass(shape)
        N = shape[0]
        q = len(listOfPartitions)
        for p, part in enumerate(listOfPartitions):
            # Partition p supplies rows p, p+q, p+2q, ...
            A[p:N:q] = part
        return A

    def flatten_list(self, listOfPartitions):
        """Interleave list/tuple partitions into one flat list.

        Takes element 0 of every partition, then element 1, and so on;
        partitions that have run out of elements are skipped.
        """
        if len(listOfPartitions) == 0:
            return []
        # Size the loop by the longest partition so no element is dropped,
        # even if the first partition is not the longest one.
        longest = max(len(part) for part in listOfPartitions)
        flat = []
        for i in range(longest):
            flat.extend([part[i] for part in listOfPartitions if len(part) > i])
        return flat
157
154
def mappable(obj):
    """Tell whether `obj` can be partitioned by the Map classes.

    True for tuples, lists, and instances of any array type registered in
    `arrayModules`; False for everything else.
    """
    # Built-in sequences qualify without consulting the array registry.
    if isinstance(obj, (tuple, list)):
        return True
    return any(isinstance(obj, entry['type']) for entry in arrayModules)
166
163
# Lookup table from a one-letter distribution code to the Map class that
# implements it ('b' -> Map, 'r' -> RoundRobinMap; presumably "block" and
# "round robin" -- confirm against callers).
dists = dict(b=Map, r=RoundRobinMap)
168
165
169
166
170
167
General Comments 0
You need to be logged in to leave comments. Login now