##// END OF EJS Templates
Merge pull request #3282 from minrk/mapgenerator...
Brian E. Granger -
r10605:9f313b4a merge
parent child Browse files
Show More
@@ -59,17 +59,19 b' else:'
59 class Map(object):
59 class Map(object):
60 """A class for partitioning a sequence using a map."""
60 """A class for partitioning a sequence using a map."""
61
61
62 def getPartition(self, seq, p, q):
62 def getPartition(self, seq, p, q, n=None):
63 """Returns the pth partition of q partitions of seq."""
63 """Returns the pth partition of q partitions of seq.
64
64
65 The length can be specified as `n`,
66 otherwise it is the value of `len(seq)`
67 """
68 n = len(seq) if n is None else n
65 # Test for error conditions here
69 # Test for error conditions here
66 if p<0 or p>=q:
70 if p<0 or p>=q:
67 print "No partition exists."
71 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
68 return
69
72
70 N = len(seq)
73 remainder = n % q
71 remainder = N % q
74 basesize = n // q
72 basesize = N // q
73
75
74 if p < remainder:
76 if p < remainder:
75 low = p * (basesize + 1)
77 low = p * (basesize + 1)
@@ -104,19 +106,14 b' class Map(object):'
104 return listOfPartitions
106 return listOfPartitions
105
107
106 class RoundRobinMap(Map):
108 class RoundRobinMap(Map):
107 """Partitions a sequence in a roun robin fashion.
109 """Partitions a sequence in a round robin fashion.
108
110
109 This currently does not work!
111 This currently does not work!
110 """
112 """
111
113
112 def getPartition(self, seq, p, q):
114 def getPartition(self, seq, p, q, n=None):
113 # if not isinstance(seq,(list,tuple)):
115 n = len(seq) if n is None else n
114 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
116 return seq[p:n:q]
115 return seq[p:len(seq):q]
116 #result = []
117 #for i in range(p,len(seq),q):
118 # result.append(seq[i])
119 #return result
120
117
121 def joinPartitions(self, listOfPartitions):
118 def joinPartitions(self, listOfPartitions):
122 testObject = listOfPartitions[0]
119 testObject = listOfPartitions[0]
@@ -161,13 +161,15 b' class ParallelFunction(RemoteFunction):'
161 chunksize : int or None
161 chunksize : int or None
162 The size of chunk to use when breaking up sequences in a load-balanced manner
162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 ordered : bool [default: True]
163 ordered : bool [default: True]
164 Whether
164 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
165 **flags : remaining kwargs are passed to View.temp_flags
166 **flags : remaining kwargs are passed to View.temp_flags
166 """
167 """
167
168
168 chunksize=None
169 chunksize = None
169 ordered=None
170 ordered = None
170 mapObject=None
171 mapObject = None
172 _mapping = False
171
173
172 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
173 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
@@ -176,23 +178,40 b' class ParallelFunction(RemoteFunction):'
176
178
177 mapClass = Map.dists[dist]
179 mapClass = Map.dists[dist]
178 self.mapObject = mapClass()
180 self.mapObject = mapClass()
179
181
180 @sync_view_results
182 @sync_view_results
181 def __call__(self, *sequences):
183 def __call__(self, *sequences):
182 client = self.view.client
184 client = self.view.client
183
185
186 lens = []
187 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
189 try:
190 n = len(seq)
191 except Exception:
192 seq = list(seq)
193 if isinstance(sequences, tuple):
194 # can't alter a tuple
195 sequences = list(sequences)
196 sequences[i] = seq
197 n = len(seq)
198 if n > maxlen:
199 maxlen = n
200 if minlen == -1 or n < minlen:
201 minlen = n
202 lens.append(n)
203
184 # check that the length of sequences match
204 # check that the length of sequences match
185 len_0 = len(sequences[0])
205 if not self._mapping and minlen != maxlen:
186 for s in sequences:
206 msg = 'all sequences must have equal length, but have %s' % lens
187 if len(s)!=len_0:
207 raise ValueError(msg)
188 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
208
189 raise ValueError(msg)
190 balanced = 'Balanced' in self.view.__class__.__name__
209 balanced = 'Balanced' in self.view.__class__.__name__
191 if balanced:
210 if balanced:
192 if self.chunksize:
211 if self.chunksize:
193 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
194 else:
213 else:
195 nparts = len_0
214 nparts = maxlen
196 targets = [None]*nparts
215 targets = [None]*nparts
197 else:
216 else:
198 if self.chunksize:
217 if self.chunksize:
@@ -211,21 +230,17 b' class ParallelFunction(RemoteFunction):'
211 for index, t in enumerate(targets):
230 for index, t in enumerate(targets):
212 args = []
231 args = []
213 for seq in sequences:
232 for seq in sequences:
214 part = self.mapObject.getPartition(seq, index, nparts)
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
215 if len(part) == 0:
234 args.append(part)
216 continue
235 if not any(args):
217 else:
218 args.append(part)
219 if not args:
220 continue
236 continue
221
237
222 # print (args)
238 if self._mapping:
223 if hasattr(self, '_map'):
224 if sys.version_info[0] >= 3:
239 if sys.version_info[0] >= 3:
225 f = lambda f, *sequences: list(map(f, *sequences))
240 f = lambda f, *sequences: list(map(f, *sequences))
226 else:
241 else:
227 f = map
242 f = map
228 args = [self.func]+args
243 args = [self.func] + args
229 else:
244 else:
230 f=self.func
245 f=self.func
231
246
@@ -233,9 +248,9 b' class ParallelFunction(RemoteFunction):'
233 with view.temp_flags(block=False, **self.flags):
248 with view.temp_flags(block=False, **self.flags):
234 ar = view.apply(f, *args)
249 ar = view.apply(f, *args)
235
250
236 msg_ids.append(ar.msg_ids[0])
251 msg_ids.extend(ar.msg_ids)
237
252
238 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
239 fname=getname(self.func),
254 fname=getname(self.func),
240 ordered=self.ordered
255 ordered=self.ordered
241 )
256 )
@@ -249,16 +264,19 b' class ParallelFunction(RemoteFunction):'
249 return r
264 return r
250
265
251 def map(self, *sequences):
266 def map(self, *sequences):
252 """call a function on each element of a sequence remotely.
267 """call a function on each element of one or more sequence(s) remotely.
253 This should behave very much like the builtin map, but return an AsyncMapResult
268 This should behave very much like the builtin map, but return an AsyncMapResult
254 if self.block is False.
269 if self.block is False.
270
271 That means it can take generators (will be cast to lists locally),
272 and mismatched sequence lengths will be padded with None.
255 """
273 """
256 # set _map as a flag for use inside self.__call__
274 # set _mapping as a flag for use inside self.__call__
257 self._map = True
275 self._mapping = True
258 try:
276 try:
259 ret = self.__call__(*sequences)
277 ret = self(*sequences)
260 finally:
278 finally:
261 del self._map
279 self._mapping = False
262 return ret
280 return ret
263
281
264 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -58,6 +58,40 b' class TestLoadBalancedView(ClusterTestCase):'
58 data = range(16)
58 data = range(16)
59 r = self.view.map_sync(f, data)
59 r = self.view.map_sync(f, data)
60 self.assertEqual(r, map(f, data))
60 self.assertEqual(r, map(f, data))
61
62 def test_map_generator(self):
63 def f(x):
64 return x**2
65
66 data = range(16)
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
69
70 def test_map_short_first(self):
71 def f(x,y):
72 if y is None:
73 return y
74 if x is None:
75 return x
76 return x*y
77 data = range(10)
78 data2 = range(4)
79
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
82
83 def test_map_short_last(self):
84 def f(x,y):
85 if y is None:
86 return y
87 if x is None:
88 return x
89 return x*y
90 data = range(4)
91 data2 = range(10)
92
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
61
95
62 def test_map_unordered(self):
96 def test_map_unordered(self):
63 def f(x):
97 def f(x):
General Comments 0
You need to be logged in to leave comments. Login now