##// END OF EJS Templates
Merge pull request #3282 from minrk/mapgenerator...
Brian E. Granger -
r10605:9f313b4a merge
parent child Browse files
Show More
@@ -1,170 +1,167 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7 7
8 8
9 9 Authors:
10 10
11 11 * Brian Granger
12 12 * MinRK
13 13
14 14 """
15 15
16 16 #-------------------------------------------------------------------------------
17 17 # Copyright (C) 2008-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-------------------------------------------------------------------------------
22 22
23 23 #-------------------------------------------------------------------------------
24 24 # Imports
25 25 #-------------------------------------------------------------------------------
26 26
27 27 from __future__ import division
28 28
29 29 import types
30 30 from itertools import islice
31 31
32 32 from IPython.utils.data import flatten as utils_flatten
33 33
34 34 #-------------------------------------------------------------------------------
35 35 # Figure out which array packages are present and their array types
36 36 #-------------------------------------------------------------------------------
37 37
38 38 arrayModules = []
39 39 try:
40 40 import Numeric
41 41 except ImportError:
42 42 pass
43 43 else:
44 44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
45 45 try:
46 46 import numpy
47 47 except ImportError:
48 48 pass
49 49 else:
50 50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
51 51 try:
52 52 import numarray
53 53 except ImportError:
54 54 pass
55 55 else:
56 56 arrayModules.append({'module':numarray,
57 57 'type':numarray.numarraycore.NumArray})
58 58
59 59 class Map(object):
60 60 """A class for partitioning a sequence using a map."""
61 61
62 def getPartition(self, seq, p, q):
63 """Returns the pth partition of q partitions of seq."""
62 def getPartition(self, seq, p, q, n=None):
63 """Returns the pth partition of q partitions of seq.
64 64
65 The length can be specified as `n`,
66 otherwise it is the value of `len(seq)`
67 """
68 n = len(seq) if n is None else n
65 69 # Test for error conditions here
66 70 if p<0 or p>=q:
67 print "No partition exists."
68 return
71 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
69 72
70 N = len(seq)
71 remainder = N % q
72 basesize = N // q
73 remainder = n % q
74 basesize = n // q
73 75
74 76 if p < remainder:
75 77 low = p * (basesize + 1)
76 78 high = low + basesize + 1
77 79 else:
78 80 low = p * basesize + remainder
79 81 high = low + basesize
80 82
81 83 try:
82 84 result = seq[low:high]
83 85 except TypeError:
84 86 # some objects (iterators) can't be sliced,
85 87 # use islice:
86 88 result = list(islice(seq, low, high))
87 89
88 90 return result
89 91
90 92 def joinPartitions(self, listOfPartitions):
91 93 return self.concatenate(listOfPartitions)
92 94
93 95 def concatenate(self, listOfPartitions):
94 96 testObject = listOfPartitions[0]
95 97 # First see if we have a known array type
96 98 for m in arrayModules:
97 99 #print m
98 100 if isinstance(testObject, m['type']):
99 101 return m['module'].concatenate(listOfPartitions)
100 102 # Next try for Python sequence types
101 103 if isinstance(testObject, (types.ListType, types.TupleType)):
102 104 return utils_flatten(listOfPartitions)
103 105 # If we have scalars, just return listOfPartitions
104 106 return listOfPartitions
105 107
106 108 class RoundRobinMap(Map):
107 """Partitions a sequence in a roun robin fashion.
109 """Partitions a sequence in a round robin fashion.
108 110
109 111 This currently does not work!
110 112 """
111 113
112 def getPartition(self, seq, p, q):
113 # if not isinstance(seq,(list,tuple)):
114 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
115 return seq[p:len(seq):q]
116 #result = []
117 #for i in range(p,len(seq),q):
118 # result.append(seq[i])
119 #return result
114 def getPartition(self, seq, p, q, n=None):
115 n = len(seq) if n is None else n
116 return seq[p:n:q]
120 117
121 118 def joinPartitions(self, listOfPartitions):
122 119 testObject = listOfPartitions[0]
123 120 # First see if we have a known array type
124 121 for m in arrayModules:
125 122 #print m
126 123 if isinstance(testObject, m['type']):
127 124 return self.flatten_array(m['type'], listOfPartitions)
128 125 if isinstance(testObject, (types.ListType, types.TupleType)):
129 126 return self.flatten_list(listOfPartitions)
130 127 return listOfPartitions
131 128
132 129 def flatten_array(self, klass, listOfPartitions):
133 130 test = listOfPartitions[0]
134 131 shape = list(test.shape)
135 132 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
136 133 A = klass(shape)
137 134 N = shape[0]
138 135 q = len(listOfPartitions)
139 136 for p,part in enumerate(listOfPartitions):
140 137 A[p:N:q] = part
141 138 return A
142 139
143 140 def flatten_list(self, listOfPartitions):
144 141 flat = []
145 142 for i in range(len(listOfPartitions[0])):
146 143 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
147 144 return flat
148 145 #lengths = [len(x) for x in listOfPartitions]
149 146 #maxPartitionLength = len(listOfPartitions[0])
150 147 #numberOfPartitions = len(listOfPartitions)
151 148 #concat = self.concatenate(listOfPartitions)
152 149 #totalLength = len(concat)
153 150 #result = []
154 151 #for i in range(maxPartitionLength):
155 152 # result.append(concat[i:totalLength:maxPartitionLength])
156 153 # return self.concatenate(listOfPartitions)
157 154
158 155 def mappable(obj):
159 156 """return whether an object is mappable or not."""
160 157 if isinstance(obj, (tuple,list)):
161 158 return True
162 159 for m in arrayModules:
163 160 if isinstance(obj,m['type']):
164 161 return True
165 162 return False
166 163
167 164 dists = {'b':Map,'r':RoundRobinMap}
168 165
169 166
170 167
@@ -1,264 +1,282 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import sys
22 22 import warnings
23 23
24 24 from IPython.external.decorator import decorator
25 25 from IPython.testing.skipdoctest import skip_doctest
26 26
27 27 from . import map as Map
28 28 from .asyncresult import AsyncMapResult
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Functions and Decorators
32 32 #-----------------------------------------------------------------------------
33 33
34 34 @skip_doctest
35 35 def remote(view, block=None, **flags):
36 36 """Turn a function into a remote function.
37 37
38 38 This method can be used for map:
39 39
40 40 In [1]: @remote(view,block=True)
41 41 ...: def func(a):
42 42 ...: pass
43 43 """
44 44
45 45 def remote_function(f):
46 46 return RemoteFunction(view, f, block=block, **flags)
47 47 return remote_function
48 48
49 49 @skip_doctest
50 50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 51 """Turn a function into a parallel remote function.
52 52
53 53 This method can be used for map:
54 54
55 55 In [1]: @parallel(view, block=True)
56 56 ...: def func(a):
57 57 ...: pass
58 58 """
59 59
60 60 def parallel_function(f):
61 61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 62 return parallel_function
63 63
64 64 def getname(f):
65 65 """Get the name of an object.
66 66
67 67 For use in case of callables that are not functions, and
68 68 thus may not have __name__ defined.
69 69
70 70 Order: f.__name__ > f.name > str(f)
71 71 """
72 72 try:
73 73 return f.__name__
74 74 except:
75 75 pass
76 76 try:
77 77 return f.name
78 78 except:
79 79 pass
80 80
81 81 return str(f)
82 82
83 83 @decorator
84 84 def sync_view_results(f, self, *args, **kwargs):
85 85 """sync relevant results from self.client to our results attribute.
86 86
87 87 This is a clone of view.sync_results, but for remote functions
88 88 """
89 89 view = self.view
90 90 if view._in_sync_results:
91 91 return f(self, *args, **kwargs)
92 92 print 'in sync results', f
93 93 view._in_sync_results = True
94 94 try:
95 95 ret = f(self, *args, **kwargs)
96 96 finally:
97 97 view._in_sync_results = False
98 98 view._sync_results()
99 99 return ret
100 100
101 101 #--------------------------------------------------------------------------
102 102 # Classes
103 103 #--------------------------------------------------------------------------
104 104
105 105 class RemoteFunction(object):
106 106 """Turn an existing function into a remote function.
107 107
108 108 Parameters
109 109 ----------
110 110
111 111 view : View instance
112 112 The view to be used for execution
113 113 f : callable
114 114 The function to be wrapped into a remote function
115 115 block : bool [default: None]
116 116 Whether to wait for results or not. The default behavior is
117 117 to use the current `block` attribute of `view`
118 118
119 119 **flags : remaining kwargs are passed to View.temp_flags
120 120 """
121 121
122 122 view = None # the remote connection
123 123 func = None # the wrapped function
124 124 block = None # whether to block
125 125 flags = None # dict of extra kwargs for temp_flags
126 126
127 127 def __init__(self, view, f, block=None, **flags):
128 128 self.view = view
129 129 self.func = f
130 130 self.block=block
131 131 self.flags=flags
132 132
133 133 def __call__(self, *args, **kwargs):
134 134 block = self.view.block if self.block is None else self.block
135 135 with self.view.temp_flags(block=block, **self.flags):
136 136 return self.view.apply(self.func, *args, **kwargs)
137 137
138 138
139 139 class ParallelFunction(RemoteFunction):
140 140 """Class for mapping a function to sequences.
141 141
142 142 This will distribute the sequences according the a mapper, and call
143 143 the function on each sub-sequence. If called via map, then the function
144 144 will be called once on each element, rather that each sub-sequence.
145 145
146 146 Parameters
147 147 ----------
148 148
149 149 view : View instance
150 150 The view to be used for execution
151 151 f : callable
152 152 The function to be wrapped into a remote function
153 153 dist : str [default: 'b']
154 154 The key for which mapObject to use to distribute sequences
155 155 options are:
156 156 * 'b' : use contiguous chunks in order
157 157 * 'r' : use round-robin striping
158 158 block : bool [default: None]
159 159 Whether to wait for results or not. The default behavior is
160 160 to use the current `block` attribute of `view`
161 161 chunksize : int or None
162 162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 163 ordered : bool [default: True]
164 Whether
164 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
165 166 **flags : remaining kwargs are passed to View.temp_flags
166 167 """
167 168
168 chunksize=None
169 ordered=None
170 mapObject=None
169 chunksize = None
170 ordered = None
171 mapObject = None
172 _mapping = False
171 173
172 174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
173 175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
174 176 self.chunksize = chunksize
175 177 self.ordered = ordered
176 178
177 179 mapClass = Map.dists[dist]
178 180 self.mapObject = mapClass()
179
181
180 182 @sync_view_results
181 183 def __call__(self, *sequences):
182 184 client = self.view.client
183 185
186 lens = []
187 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
189 try:
190 n = len(seq)
191 except Exception:
192 seq = list(seq)
193 if isinstance(sequences, tuple):
194 # can't alter a tuple
195 sequences = list(sequences)
196 sequences[i] = seq
197 n = len(seq)
198 if n > maxlen:
199 maxlen = n
200 if minlen == -1 or n < minlen:
201 minlen = n
202 lens.append(n)
203
184 204 # check that the length of sequences match
185 len_0 = len(sequences[0])
186 for s in sequences:
187 if len(s)!=len_0:
188 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
189 raise ValueError(msg)
205 if not self._mapping and minlen != maxlen:
206 msg = 'all sequences must have equal length, but have %s' % lens
207 raise ValueError(msg)
208
190 209 balanced = 'Balanced' in self.view.__class__.__name__
191 210 if balanced:
192 211 if self.chunksize:
193 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
194 213 else:
195 nparts = len_0
214 nparts = maxlen
196 215 targets = [None]*nparts
197 216 else:
198 217 if self.chunksize:
199 218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
200 219 # multiplexed:
201 220 targets = self.view.targets
202 221 # 'all' is lazily evaluated at execution time, which is now:
203 222 if targets == 'all':
204 223 targets = client._build_targets(targets)[1]
205 224 elif isinstance(targets, int):
206 225 # single-engine view, targets must be iterable
207 226 targets = [targets]
208 227 nparts = len(targets)
209 228
210 229 msg_ids = []
211 230 for index, t in enumerate(targets):
212 231 args = []
213 232 for seq in sequences:
214 part = self.mapObject.getPartition(seq, index, nparts)
215 if len(part) == 0:
216 continue
217 else:
218 args.append(part)
219 if not args:
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
234 args.append(part)
235 if not any(args):
220 236 continue
221 237
222 # print (args)
223 if hasattr(self, '_map'):
238 if self._mapping:
224 239 if sys.version_info[0] >= 3:
225 240 f = lambda f, *sequences: list(map(f, *sequences))
226 241 else:
227 242 f = map
228 args = [self.func]+args
243 args = [self.func] + args
229 244 else:
230 245 f=self.func
231 246
232 247 view = self.view if balanced else client[t]
233 248 with view.temp_flags(block=False, **self.flags):
234 249 ar = view.apply(f, *args)
235 250
236 msg_ids.append(ar.msg_ids[0])
251 msg_ids.extend(ar.msg_ids)
237 252
238 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
239 254 fname=getname(self.func),
240 255 ordered=self.ordered
241 256 )
242 257
243 258 if self.block:
244 259 try:
245 260 return r.get()
246 261 except KeyboardInterrupt:
247 262 return r
248 263 else:
249 264 return r
250 265
251 266 def map(self, *sequences):
252 """call a function on each element of a sequence remotely.
267 """call a function on each element of one or more sequence(s) remotely.
253 268 This should behave very much like the builtin map, but return an AsyncMapResult
254 269 if self.block is False.
270
271 That means it can take generators (will be cast to lists locally),
272 and mismatched sequence lengths will be padded with None.
255 273 """
256 # set _map as a flag for use inside self.__call__
257 self._map = True
274 # set _mapping as a flag for use inside self.__call__
275 self._mapping = True
258 276 try:
259 ret = self.__call__(*sequences)
277 ret = self(*sequences)
260 278 finally:
261 del self._map
279 self._mapping = False
262 280 return ret
263 281
264 282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,177 +1,211 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test LoadBalancedView objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from nose import SkipTest
24 24 from nose.plugins.attrib import attr
25 25
26 26 from IPython import parallel as pmod
27 27 from IPython.parallel import error
28 28
29 29 from IPython.parallel.tests import add_engines
30 30
31 31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32 32
33 33 def setup():
34 34 add_engines(3, total=True)
35 35
36 36 class TestLoadBalancedView(ClusterTestCase):
37 37
38 38 def setUp(self):
39 39 ClusterTestCase.setUp(self)
40 40 self.view = self.client.load_balanced_view()
41 41
42 42 @attr('crash')
43 43 def test_z_crash_task(self):
44 44 """test graceful handling of engine death (balanced)"""
45 45 # self.add_engines(1)
46 46 ar = self.view.apply_async(crash)
47 47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 48 eid = ar.engine_id
49 49 tic = time.time()
50 50 while eid in self.client.ids and time.time()-tic < 5:
51 51 time.sleep(.01)
52 52 self.client.spin()
53 53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54 54
55 55 def test_map(self):
56 56 def f(x):
57 57 return x**2
58 58 data = range(16)
59 59 r = self.view.map_sync(f, data)
60 60 self.assertEqual(r, map(f, data))
61
62 def test_map_generator(self):
63 def f(x):
64 return x**2
65
66 data = range(16)
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
69
70 def test_map_short_first(self):
71 def f(x,y):
72 if y is None:
73 return y
74 if x is None:
75 return x
76 return x*y
77 data = range(10)
78 data2 = range(4)
79
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
82
83 def test_map_short_last(self):
84 def f(x,y):
85 if y is None:
86 return y
87 if x is None:
88 return x
89 return x*y
90 data = range(4)
91 data2 = range(10)
92
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
61 95
62 96 def test_map_unordered(self):
63 97 def f(x):
64 98 return x**2
65 99 def slow_f(x):
66 100 import time
67 101 time.sleep(0.05*x)
68 102 return x**2
69 103 data = range(16,0,-1)
70 104 reference = map(f, data)
71 105
72 106 amr = self.view.map_async(slow_f, data, ordered=False)
73 107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
74 108 # check individual elements, retrieved as they come
75 109 # list comprehension uses __iter__
76 110 astheycame = [ r for r in amr ]
77 111 # Ensure that at least one result came out of order:
78 112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
79 113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
80 114
81 115 def test_map_ordered(self):
82 116 def f(x):
83 117 return x**2
84 118 def slow_f(x):
85 119 import time
86 120 time.sleep(0.05*x)
87 121 return x**2
88 122 data = range(16,0,-1)
89 123 reference = map(f, data)
90 124
91 125 amr = self.view.map_async(slow_f, data)
92 126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
93 127 # check individual elements, retrieved as they come
94 128 # list(amr) uses __iter__
95 129 astheycame = list(amr)
96 130 # Ensure that results came in order
97 131 self.assertEqual(astheycame, reference)
98 132 self.assertEqual(amr.result, reference)
99 133
100 134 def test_map_iterable(self):
101 135 """test map on iterables (balanced)"""
102 136 view = self.view
103 137 # 101 is prime, so it won't be evenly distributed
104 138 arr = range(101)
105 139 # so that it will be an iterator, even in Python 3
106 140 it = iter(arr)
107 141 r = view.map_sync(lambda x:x, arr)
108 142 self.assertEqual(r, list(arr))
109 143
110 144
111 145 def test_abort(self):
112 146 view = self.view
113 147 ar = self.client[:].apply_async(time.sleep, .5)
114 148 ar = self.client[:].apply_async(time.sleep, .5)
115 149 time.sleep(0.2)
116 150 ar2 = view.apply_async(lambda : 2)
117 151 ar3 = view.apply_async(lambda : 3)
118 152 view.abort(ar2)
119 153 view.abort(ar3.msg_ids)
120 154 self.assertRaises(error.TaskAborted, ar2.get)
121 155 self.assertRaises(error.TaskAborted, ar3.get)
122 156
123 157 def test_retries(self):
124 158 view = self.view
125 159 view.timeout = 1 # prevent hang if this doesn't behave
126 160 def fail():
127 161 assert False
128 162 for r in range(len(self.client)-1):
129 163 with view.temp_flags(retries=r):
130 164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
131 165
132 166 with view.temp_flags(retries=len(self.client), timeout=0.25):
133 167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
134 168
135 169 def test_invalid_dependency(self):
136 170 view = self.view
137 171 with view.temp_flags(after='12345'):
138 172 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
139 173
140 174 def test_impossible_dependency(self):
141 175 self.minimum_engines(2)
142 176 view = self.client.load_balanced_view()
143 177 ar1 = view.apply_async(lambda : 1)
144 178 ar1.get()
145 179 e1 = ar1.engine_id
146 180 e2 = e1
147 181 while e2 == e1:
148 182 ar2 = view.apply_async(lambda : 1)
149 183 ar2.get()
150 184 e2 = ar2.engine_id
151 185
152 186 with view.temp_flags(follow=[ar1, ar2]):
153 187 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
154 188
155 189
156 190 def test_follow(self):
157 191 ar = self.view.apply_async(lambda : 1)
158 192 ar.get()
159 193 ars = []
160 194 first_id = ar.engine_id
161 195
162 196 self.view.follow = ar
163 197 for i in range(5):
164 198 ars.append(self.view.apply_async(lambda : 1))
165 199 self.view.wait(ars)
166 200 for ar in ars:
167 201 self.assertEqual(ar.engine_id, first_id)
168 202
169 203 def test_after(self):
170 204 view = self.view
171 205 ar = view.apply_async(time.sleep, 0.5)
172 206 with view.temp_flags(after=ar):
173 207 ar2 = view.apply_async(lambda : 1)
174 208
175 209 ar.wait()
176 210 ar2.wait()
177 211 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
General Comments 0
You need to be logged in to leave comments. Login now