##// END OF EJS Templates
Merge pull request #3282 from minrk/mapgenerator...
Brian E. Granger -
r10605:9f313b4a merge
parent child Browse files
Show More
@@ -59,17 +59,19 b' else:'
59 59 class Map(object):
60 60 """A class for partitioning a sequence using a map."""
61 61
62 def getPartition(self, seq, p, q):
63 """Returns the pth partition of q partitions of seq."""
62 def getPartition(self, seq, p, q, n=None):
63 """Returns the pth partition of q partitions of seq.
64 64
65 The length can be specified as `n`,
66 otherwise it is the value of `len(seq)`
67 """
68 n = len(seq) if n is None else n
65 69 # Test for error conditions here
66 70 if p<0 or p>=q:
67 print "No partition exists."
68 return
71 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
69 72
70 N = len(seq)
71 remainder = N % q
72 basesize = N // q
73 remainder = n % q
74 basesize = n // q
73 75
74 76 if p < remainder:
75 77 low = p * (basesize + 1)
@@ -104,19 +106,14 b' class Map(object):'
104 106 return listOfPartitions
105 107
106 108 class RoundRobinMap(Map):
107 """Partitions a sequence in a roun robin fashion.
109 """Partitions a sequence in a round robin fashion.
108 110
109 111 This currently does not work!
110 112 """
111 113
112 def getPartition(self, seq, p, q):
113 # if not isinstance(seq,(list,tuple)):
114 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
115 return seq[p:len(seq):q]
116 #result = []
117 #for i in range(p,len(seq),q):
118 # result.append(seq[i])
119 #return result
114 def getPartition(self, seq, p, q, n=None):
115 n = len(seq) if n is None else n
116 return seq[p:n:q]
120 117
121 118 def joinPartitions(self, listOfPartitions):
122 119 testObject = listOfPartitions[0]
@@ -161,13 +161,15 b' class ParallelFunction(RemoteFunction):'
161 161 chunksize : int or None
162 162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 163 ordered : bool [default: True]
164 Whether
164 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
165 166 **flags : remaining kwargs are passed to View.temp_flags
166 167 """
167 168
168 169 chunksize=None
169 170 ordered=None
170 171 mapObject=None
172 _mapping = False
171 173
172 174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
173 175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
@@ -181,18 +183,35 b' class ParallelFunction(RemoteFunction):'
181 183 def __call__(self, *sequences):
182 184 client = self.view.client
183 185
186 lens = []
187 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
189 try:
190 n = len(seq)
191 except Exception:
192 seq = list(seq)
193 if isinstance(sequences, tuple):
194 # can't alter a tuple
195 sequences = list(sequences)
196 sequences[i] = seq
197 n = len(seq)
198 if n > maxlen:
199 maxlen = n
200 if minlen == -1 or n < minlen:
201 minlen = n
202 lens.append(n)
203
184 204 # check that the length of sequences match
185 len_0 = len(sequences[0])
186 for s in sequences:
187 if len(s)!=len_0:
188 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
205 if not self._mapping and minlen != maxlen:
206 msg = 'all sequences must have equal length, but have %s' % lens
189 207 raise ValueError(msg)
208
190 209 balanced = 'Balanced' in self.view.__class__.__name__
191 210 if balanced:
192 211 if self.chunksize:
193 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
194 213 else:
195 nparts = len_0
214 nparts = maxlen
196 215 targets = [None]*nparts
197 216 else:
198 217 if self.chunksize:
@@ -211,16 +230,12 b' class ParallelFunction(RemoteFunction):'
211 230 for index, t in enumerate(targets):
212 231 args = []
213 232 for seq in sequences:
214 part = self.mapObject.getPartition(seq, index, nparts)
215 if len(part) == 0:
216 continue
217 else:
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
218 234 args.append(part)
219 if not args:
235 if not any(args):
220 236 continue
221 237
222 # print (args)
223 if hasattr(self, '_map'):
238 if self._mapping:
224 239 if sys.version_info[0] >= 3:
225 240 f = lambda f, *sequences: list(map(f, *sequences))
226 241 else:
@@ -233,7 +248,7 b' class ParallelFunction(RemoteFunction):'
233 248 with view.temp_flags(block=False, **self.flags):
234 249 ar = view.apply(f, *args)
235 250
236 msg_ids.append(ar.msg_ids[0])
251 msg_ids.extend(ar.msg_ids)
237 252
238 253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
239 254 fname=getname(self.func),
@@ -249,16 +264,19 b' class ParallelFunction(RemoteFunction):'
249 264 return r
250 265
251 266 def map(self, *sequences):
252 """call a function on each element of a sequence remotely.
267 """call a function on each element of one or more sequence(s) remotely.
253 268 This should behave very much like the builtin map, but return an AsyncMapResult
254 269 if self.block is False.
270
271 That means it can take generators (will be cast to lists locally),
272 and mismatched sequence lengths will be padded with None.
255 273 """
256 # set _map as a flag for use inside self.__call__
257 self._map = True
274 # set _mapping as a flag for use inside self.__call__
275 self._mapping = True
258 276 try:
259 ret = self.__call__(*sequences)
277 ret = self(*sequences)
260 278 finally:
261 del self._map
279 self._mapping = False
262 280 return ret
263 281
264 282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -59,6 +59,40 b' class TestLoadBalancedView(ClusterTestCase):'
59 59 r = self.view.map_sync(f, data)
60 60 self.assertEqual(r, map(f, data))
61 61
62 def test_map_generator(self):
63 def f(x):
64 return x**2
65
66 data = range(16)
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
69
70 def test_map_short_first(self):
71 def f(x,y):
72 if y is None:
73 return y
74 if x is None:
75 return x
76 return x*y
77 data = range(10)
78 data2 = range(4)
79
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
82
83 def test_map_short_last(self):
84 def f(x,y):
85 if y is None:
86 return y
87 if x is None:
88 return x
89 return x*y
90 data = range(4)
91 data2 = range(10)
92
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
95
62 96 def test_map_unordered(self):
63 97 def f(x):
64 98 return x**2
General Comments 0
You need to be logged in to leave comments. Login now