##// END OF EJS Templates
Merge pull request #3282 from minrk/mapgenerator...
Brian E. Granger -
r10605:9f313b4a merge
parent child Browse files
Show More
@@ -1,170 +1,167 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
7
8
8
9 Authors:
9 Authors:
10
10
11 * Brian Granger
11 * Brian Granger
12 * MinRK
12 * MinRK
13
13
14 """
14 """
15
15
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
21 #-------------------------------------------------------------------------------
22
22
23 #-------------------------------------------------------------------------------
23 #-------------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26
26
27 from __future__ import division
27 from __future__ import division
28
28
29 import types
29 import types
30 from itertools import islice
30 from itertools import islice
31
31
32 from IPython.utils.data import flatten as utils_flatten
32 from IPython.utils.data import flatten as utils_flatten
33
33
34 #-------------------------------------------------------------------------------
34 #-------------------------------------------------------------------------------
35 # Figure out which array packages are present and their array types
35 # Figure out which array packages are present and their array types
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37
37
38 arrayModules = []
38 arrayModules = []
39 try:
39 try:
40 import Numeric
40 import Numeric
41 except ImportError:
41 except ImportError:
42 pass
42 pass
43 else:
43 else:
44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
45 try:
45 try:
46 import numpy
46 import numpy
47 except ImportError:
47 except ImportError:
48 pass
48 pass
49 else:
49 else:
50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
51 try:
51 try:
52 import numarray
52 import numarray
53 except ImportError:
53 except ImportError:
54 pass
54 pass
55 else:
55 else:
56 arrayModules.append({'module':numarray,
56 arrayModules.append({'module':numarray,
57 'type':numarray.numarraycore.NumArray})
57 'type':numarray.numarraycore.NumArray})
58
58
59 class Map(object):
59 class Map(object):
60 """A class for partitioning a sequence using a map."""
60 """A class for partitioning a sequence using a map."""
61
61
62 def getPartition(self, seq, p, q):
62 def getPartition(self, seq, p, q, n=None):
63 """Returns the pth partition of q partitions of seq."""
63 """Returns the pth partition of q partitions of seq.
64
64
65 The length can be specified as `n`,
66 otherwise it is the value of `len(seq)`
67 """
68 n = len(seq) if n is None else n
65 # Test for error conditions here
69 # Test for error conditions here
66 if p<0 or p>=q:
70 if p<0 or p>=q:
67 print "No partition exists."
71 raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))
68 return
69
72
70 N = len(seq)
73 remainder = n % q
71 remainder = N % q
74 basesize = n // q
72 basesize = N // q
73
75
74 if p < remainder:
76 if p < remainder:
75 low = p * (basesize + 1)
77 low = p * (basesize + 1)
76 high = low + basesize + 1
78 high = low + basesize + 1
77 else:
79 else:
78 low = p * basesize + remainder
80 low = p * basesize + remainder
79 high = low + basesize
81 high = low + basesize
80
82
81 try:
83 try:
82 result = seq[low:high]
84 result = seq[low:high]
83 except TypeError:
85 except TypeError:
84 # some objects (iterators) can't be sliced,
86 # some objects (iterators) can't be sliced,
85 # use islice:
87 # use islice:
86 result = list(islice(seq, low, high))
88 result = list(islice(seq, low, high))
87
89
88 return result
90 return result
89
91
90 def joinPartitions(self, listOfPartitions):
92 def joinPartitions(self, listOfPartitions):
91 return self.concatenate(listOfPartitions)
93 return self.concatenate(listOfPartitions)
92
94
93 def concatenate(self, listOfPartitions):
95 def concatenate(self, listOfPartitions):
94 testObject = listOfPartitions[0]
96 testObject = listOfPartitions[0]
95 # First see if we have a known array type
97 # First see if we have a known array type
96 for m in arrayModules:
98 for m in arrayModules:
97 #print m
99 #print m
98 if isinstance(testObject, m['type']):
100 if isinstance(testObject, m['type']):
99 return m['module'].concatenate(listOfPartitions)
101 return m['module'].concatenate(listOfPartitions)
100 # Next try for Python sequence types
102 # Next try for Python sequence types
101 if isinstance(testObject, (types.ListType, types.TupleType)):
103 if isinstance(testObject, (types.ListType, types.TupleType)):
102 return utils_flatten(listOfPartitions)
104 return utils_flatten(listOfPartitions)
103 # If we have scalars, just return listOfPartitions
105 # If we have scalars, just return listOfPartitions
104 return listOfPartitions
106 return listOfPartitions
105
107
106 class RoundRobinMap(Map):
108 class RoundRobinMap(Map):
107 """Partitions a sequence in a roun robin fashion.
109 """Partitions a sequence in a round robin fashion.
108
110
109 This currently does not work!
111 This currently does not work!
110 """
112 """
111
113
112 def getPartition(self, seq, p, q):
114 def getPartition(self, seq, p, q, n=None):
113 # if not isinstance(seq,(list,tuple)):
115 n = len(seq) if n is None else n
114 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
116 return seq[p:n:q]
115 return seq[p:len(seq):q]
116 #result = []
117 #for i in range(p,len(seq),q):
118 # result.append(seq[i])
119 #return result
120
117
121 def joinPartitions(self, listOfPartitions):
118 def joinPartitions(self, listOfPartitions):
122 testObject = listOfPartitions[0]
119 testObject = listOfPartitions[0]
123 # First see if we have a known array type
120 # First see if we have a known array type
124 for m in arrayModules:
121 for m in arrayModules:
125 #print m
122 #print m
126 if isinstance(testObject, m['type']):
123 if isinstance(testObject, m['type']):
127 return self.flatten_array(m['type'], listOfPartitions)
124 return self.flatten_array(m['type'], listOfPartitions)
128 if isinstance(testObject, (types.ListType, types.TupleType)):
125 if isinstance(testObject, (types.ListType, types.TupleType)):
129 return self.flatten_list(listOfPartitions)
126 return self.flatten_list(listOfPartitions)
130 return listOfPartitions
127 return listOfPartitions
131
128
132 def flatten_array(self, klass, listOfPartitions):
129 def flatten_array(self, klass, listOfPartitions):
133 test = listOfPartitions[0]
130 test = listOfPartitions[0]
134 shape = list(test.shape)
131 shape = list(test.shape)
135 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
132 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
136 A = klass(shape)
133 A = klass(shape)
137 N = shape[0]
134 N = shape[0]
138 q = len(listOfPartitions)
135 q = len(listOfPartitions)
139 for p,part in enumerate(listOfPartitions):
136 for p,part in enumerate(listOfPartitions):
140 A[p:N:q] = part
137 A[p:N:q] = part
141 return A
138 return A
142
139
143 def flatten_list(self, listOfPartitions):
140 def flatten_list(self, listOfPartitions):
144 flat = []
141 flat = []
145 for i in range(len(listOfPartitions[0])):
142 for i in range(len(listOfPartitions[0])):
146 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
143 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
147 return flat
144 return flat
148 #lengths = [len(x) for x in listOfPartitions]
145 #lengths = [len(x) for x in listOfPartitions]
149 #maxPartitionLength = len(listOfPartitions[0])
146 #maxPartitionLength = len(listOfPartitions[0])
150 #numberOfPartitions = len(listOfPartitions)
147 #numberOfPartitions = len(listOfPartitions)
151 #concat = self.concatenate(listOfPartitions)
148 #concat = self.concatenate(listOfPartitions)
152 #totalLength = len(concat)
149 #totalLength = len(concat)
153 #result = []
150 #result = []
154 #for i in range(maxPartitionLength):
151 #for i in range(maxPartitionLength):
155 # result.append(concat[i:totalLength:maxPartitionLength])
152 # result.append(concat[i:totalLength:maxPartitionLength])
156 # return self.concatenate(listOfPartitions)
153 # return self.concatenate(listOfPartitions)
157
154
158 def mappable(obj):
155 def mappable(obj):
159 """return whether an object is mappable or not."""
156 """return whether an object is mappable or not."""
160 if isinstance(obj, (tuple,list)):
157 if isinstance(obj, (tuple,list)):
161 return True
158 return True
162 for m in arrayModules:
159 for m in arrayModules:
163 if isinstance(obj,m['type']):
160 if isinstance(obj,m['type']):
164 return True
161 return True
165 return False
162 return False
166
163
167 dists = {'b':Map,'r':RoundRobinMap}
164 dists = {'b':Map,'r':RoundRobinMap}
168
165
169
166
170
167
@@ -1,264 +1,282 b''
1 """Remote Functions and decorators for Views.
1 """Remote Functions and decorators for Views.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 * Min RK
6 * Min RK
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import sys
21 import sys
22 import warnings
22 import warnings
23
23
24 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
25 from IPython.testing.skipdoctest import skip_doctest
25 from IPython.testing.skipdoctest import skip_doctest
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncMapResult
28 from .asyncresult import AsyncMapResult
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Functions and Decorators
31 # Functions and Decorators
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 @skip_doctest
34 @skip_doctest
35 def remote(view, block=None, **flags):
35 def remote(view, block=None, **flags):
36 """Turn a function into a remote function.
36 """Turn a function into a remote function.
37
37
38 This method can be used for map:
38 This method can be used for map:
39
39
40 In [1]: @remote(view,block=True)
40 In [1]: @remote(view,block=True)
41 ...: def func(a):
41 ...: def func(a):
42 ...: pass
42 ...: pass
43 """
43 """
44
44
45 def remote_function(f):
45 def remote_function(f):
46 return RemoteFunction(view, f, block=block, **flags)
46 return RemoteFunction(view, f, block=block, **flags)
47 return remote_function
47 return remote_function
48
48
49 @skip_doctest
49 @skip_doctest
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
50 def parallel(view, dist='b', block=None, ordered=True, **flags):
51 """Turn a function into a parallel remote function.
51 """Turn a function into a parallel remote function.
52
52
53 This method can be used for map:
53 This method can be used for map:
54
54
55 In [1]: @parallel(view, block=True)
55 In [1]: @parallel(view, block=True)
56 ...: def func(a):
56 ...: def func(a):
57 ...: pass
57 ...: pass
58 """
58 """
59
59
60 def parallel_function(f):
60 def parallel_function(f):
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
61 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
62 return parallel_function
62 return parallel_function
63
63
64 def getname(f):
64 def getname(f):
65 """Get the name of an object.
65 """Get the name of an object.
66
66
67 For use in case of callables that are not functions, and
67 For use in case of callables that are not functions, and
68 thus may not have __name__ defined.
68 thus may not have __name__ defined.
69
69
70 Order: f.__name__ > f.name > str(f)
70 Order: f.__name__ > f.name > str(f)
71 """
71 """
72 try:
72 try:
73 return f.__name__
73 return f.__name__
74 except:
74 except:
75 pass
75 pass
76 try:
76 try:
77 return f.name
77 return f.name
78 except:
78 except:
79 pass
79 pass
80
80
81 return str(f)
81 return str(f)
82
82
83 @decorator
83 @decorator
84 def sync_view_results(f, self, *args, **kwargs):
84 def sync_view_results(f, self, *args, **kwargs):
85 """sync relevant results from self.client to our results attribute.
85 """sync relevant results from self.client to our results attribute.
86
86
87 This is a clone of view.sync_results, but for remote functions
87 This is a clone of view.sync_results, but for remote functions
88 """
88 """
89 view = self.view
89 view = self.view
90 if view._in_sync_results:
90 if view._in_sync_results:
91 return f(self, *args, **kwargs)
91 return f(self, *args, **kwargs)
92 print 'in sync results', f
92 print 'in sync results', f
93 view._in_sync_results = True
93 view._in_sync_results = True
94 try:
94 try:
95 ret = f(self, *args, **kwargs)
95 ret = f(self, *args, **kwargs)
96 finally:
96 finally:
97 view._in_sync_results = False
97 view._in_sync_results = False
98 view._sync_results()
98 view._sync_results()
99 return ret
99 return ret
100
100
101 #--------------------------------------------------------------------------
101 #--------------------------------------------------------------------------
102 # Classes
102 # Classes
103 #--------------------------------------------------------------------------
103 #--------------------------------------------------------------------------
104
104
105 class RemoteFunction(object):
105 class RemoteFunction(object):
106 """Turn an existing function into a remote function.
106 """Turn an existing function into a remote function.
107
107
108 Parameters
108 Parameters
109 ----------
109 ----------
110
110
111 view : View instance
111 view : View instance
112 The view to be used for execution
112 The view to be used for execution
113 f : callable
113 f : callable
114 The function to be wrapped into a remote function
114 The function to be wrapped into a remote function
115 block : bool [default: None]
115 block : bool [default: None]
116 Whether to wait for results or not. The default behavior is
116 Whether to wait for results or not. The default behavior is
117 to use the current `block` attribute of `view`
117 to use the current `block` attribute of `view`
118
118
119 **flags : remaining kwargs are passed to View.temp_flags
119 **flags : remaining kwargs are passed to View.temp_flags
120 """
120 """
121
121
122 view = None # the remote connection
122 view = None # the remote connection
123 func = None # the wrapped function
123 func = None # the wrapped function
124 block = None # whether to block
124 block = None # whether to block
125 flags = None # dict of extra kwargs for temp_flags
125 flags = None # dict of extra kwargs for temp_flags
126
126
127 def __init__(self, view, f, block=None, **flags):
127 def __init__(self, view, f, block=None, **flags):
128 self.view = view
128 self.view = view
129 self.func = f
129 self.func = f
130 self.block=block
130 self.block=block
131 self.flags=flags
131 self.flags=flags
132
132
133 def __call__(self, *args, **kwargs):
133 def __call__(self, *args, **kwargs):
134 block = self.view.block if self.block is None else self.block
134 block = self.view.block if self.block is None else self.block
135 with self.view.temp_flags(block=block, **self.flags):
135 with self.view.temp_flags(block=block, **self.flags):
136 return self.view.apply(self.func, *args, **kwargs)
136 return self.view.apply(self.func, *args, **kwargs)
137
137
138
138
139 class ParallelFunction(RemoteFunction):
139 class ParallelFunction(RemoteFunction):
140 """Class for mapping a function to sequences.
140 """Class for mapping a function to sequences.
141
141
142 This will distribute the sequences according the a mapper, and call
142 This will distribute the sequences according the a mapper, and call
143 the function on each sub-sequence. If called via map, then the function
143 the function on each sub-sequence. If called via map, then the function
144 will be called once on each element, rather that each sub-sequence.
144 will be called once on each element, rather that each sub-sequence.
145
145
146 Parameters
146 Parameters
147 ----------
147 ----------
148
148
149 view : View instance
149 view : View instance
150 The view to be used for execution
150 The view to be used for execution
151 f : callable
151 f : callable
152 The function to be wrapped into a remote function
152 The function to be wrapped into a remote function
153 dist : str [default: 'b']
153 dist : str [default: 'b']
154 The key for which mapObject to use to distribute sequences
154 The key for which mapObject to use to distribute sequences
155 options are:
155 options are:
156 * 'b' : use contiguous chunks in order
156 * 'b' : use contiguous chunks in order
157 * 'r' : use round-robin striping
157 * 'r' : use round-robin striping
158 block : bool [default: None]
158 block : bool [default: None]
159 Whether to wait for results or not. The default behavior is
159 Whether to wait for results or not. The default behavior is
160 to use the current `block` attribute of `view`
160 to use the current `block` attribute of `view`
161 chunksize : int or None
161 chunksize : int or None
162 The size of chunk to use when breaking up sequences in a load-balanced manner
162 The size of chunk to use when breaking up sequences in a load-balanced manner
163 ordered : bool [default: True]
163 ordered : bool [default: True]
164 Whether
164 Whether the result should be kept in order. If False,
165 results become available as they arrive, regardless of submission order.
165 **flags : remaining kwargs are passed to View.temp_flags
166 **flags : remaining kwargs are passed to View.temp_flags
166 """
167 """
167
168
168 chunksize=None
169 chunksize = None
169 ordered=None
170 ordered = None
170 mapObject=None
171 mapObject = None
172 _mapping = False
171
173
172 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
174 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
173 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
175 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
174 self.chunksize = chunksize
176 self.chunksize = chunksize
175 self.ordered = ordered
177 self.ordered = ordered
176
178
177 mapClass = Map.dists[dist]
179 mapClass = Map.dists[dist]
178 self.mapObject = mapClass()
180 self.mapObject = mapClass()
179
181
180 @sync_view_results
182 @sync_view_results
181 def __call__(self, *sequences):
183 def __call__(self, *sequences):
182 client = self.view.client
184 client = self.view.client
183
185
186 lens = []
187 maxlen = minlen = -1
188 for i, seq in enumerate(sequences):
189 try:
190 n = len(seq)
191 except Exception:
192 seq = list(seq)
193 if isinstance(sequences, tuple):
194 # can't alter a tuple
195 sequences = list(sequences)
196 sequences[i] = seq
197 n = len(seq)
198 if n > maxlen:
199 maxlen = n
200 if minlen == -1 or n < minlen:
201 minlen = n
202 lens.append(n)
203
184 # check that the length of sequences match
204 # check that the length of sequences match
185 len_0 = len(sequences[0])
205 if not self._mapping and minlen != maxlen:
186 for s in sequences:
206 msg = 'all sequences must have equal length, but have %s' % lens
187 if len(s)!=len_0:
207 raise ValueError(msg)
188 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
208
189 raise ValueError(msg)
190 balanced = 'Balanced' in self.view.__class__.__name__
209 balanced = 'Balanced' in self.view.__class__.__name__
191 if balanced:
210 if balanced:
192 if self.chunksize:
211 if self.chunksize:
193 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
212 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
194 else:
213 else:
195 nparts = len_0
214 nparts = maxlen
196 targets = [None]*nparts
215 targets = [None]*nparts
197 else:
216 else:
198 if self.chunksize:
217 if self.chunksize:
199 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
218 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
200 # multiplexed:
219 # multiplexed:
201 targets = self.view.targets
220 targets = self.view.targets
202 # 'all' is lazily evaluated at execution time, which is now:
221 # 'all' is lazily evaluated at execution time, which is now:
203 if targets == 'all':
222 if targets == 'all':
204 targets = client._build_targets(targets)[1]
223 targets = client._build_targets(targets)[1]
205 elif isinstance(targets, int):
224 elif isinstance(targets, int):
206 # single-engine view, targets must be iterable
225 # single-engine view, targets must be iterable
207 targets = [targets]
226 targets = [targets]
208 nparts = len(targets)
227 nparts = len(targets)
209
228
210 msg_ids = []
229 msg_ids = []
211 for index, t in enumerate(targets):
230 for index, t in enumerate(targets):
212 args = []
231 args = []
213 for seq in sequences:
232 for seq in sequences:
214 part = self.mapObject.getPartition(seq, index, nparts)
233 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
215 if len(part) == 0:
234 args.append(part)
216 continue
235 if not any(args):
217 else:
218 args.append(part)
219 if not args:
220 continue
236 continue
221
237
222 # print (args)
238 if self._mapping:
223 if hasattr(self, '_map'):
224 if sys.version_info[0] >= 3:
239 if sys.version_info[0] >= 3:
225 f = lambda f, *sequences: list(map(f, *sequences))
240 f = lambda f, *sequences: list(map(f, *sequences))
226 else:
241 else:
227 f = map
242 f = map
228 args = [self.func]+args
243 args = [self.func] + args
229 else:
244 else:
230 f=self.func
245 f=self.func
231
246
232 view = self.view if balanced else client[t]
247 view = self.view if balanced else client[t]
233 with view.temp_flags(block=False, **self.flags):
248 with view.temp_flags(block=False, **self.flags):
234 ar = view.apply(f, *args)
249 ar = view.apply(f, *args)
235
250
236 msg_ids.append(ar.msg_ids[0])
251 msg_ids.extend(ar.msg_ids)
237
252
238 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
253 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
239 fname=getname(self.func),
254 fname=getname(self.func),
240 ordered=self.ordered
255 ordered=self.ordered
241 )
256 )
242
257
243 if self.block:
258 if self.block:
244 try:
259 try:
245 return r.get()
260 return r.get()
246 except KeyboardInterrupt:
261 except KeyboardInterrupt:
247 return r
262 return r
248 else:
263 else:
249 return r
264 return r
250
265
251 def map(self, *sequences):
266 def map(self, *sequences):
252 """call a function on each element of a sequence remotely.
267 """call a function on each element of one or more sequence(s) remotely.
253 This should behave very much like the builtin map, but return an AsyncMapResult
268 This should behave very much like the builtin map, but return an AsyncMapResult
254 if self.block is False.
269 if self.block is False.
270
271 That means it can take generators (will be cast to lists locally),
272 and mismatched sequence lengths will be padded with None.
255 """
273 """
256 # set _map as a flag for use inside self.__call__
274 # set _mapping as a flag for use inside self.__call__
257 self._map = True
275 self._mapping = True
258 try:
276 try:
259 ret = self.__call__(*sequences)
277 ret = self(*sequences)
260 finally:
278 finally:
261 del self._map
279 self._mapping = False
262 return ret
280 return ret
263
281
264 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
282 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,177 +1,211 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test LoadBalancedView objects
2 """test LoadBalancedView objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from nose import SkipTest
23 from nose import SkipTest
24 from nose.plugins.attrib import attr
24 from nose.plugins.attrib import attr
25
25
26 from IPython import parallel as pmod
26 from IPython import parallel as pmod
27 from IPython.parallel import error
27 from IPython.parallel import error
28
28
29 from IPython.parallel.tests import add_engines
29 from IPython.parallel.tests import add_engines
30
30
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32
32
33 def setup():
33 def setup():
34 add_engines(3, total=True)
34 add_engines(3, total=True)
35
35
36 class TestLoadBalancedView(ClusterTestCase):
36 class TestLoadBalancedView(ClusterTestCase):
37
37
38 def setUp(self):
38 def setUp(self):
39 ClusterTestCase.setUp(self)
39 ClusterTestCase.setUp(self)
40 self.view = self.client.load_balanced_view()
40 self.view = self.client.load_balanced_view()
41
41
42 @attr('crash')
42 @attr('crash')
43 def test_z_crash_task(self):
43 def test_z_crash_task(self):
44 """test graceful handling of engine death (balanced)"""
44 """test graceful handling of engine death (balanced)"""
45 # self.add_engines(1)
45 # self.add_engines(1)
46 ar = self.view.apply_async(crash)
46 ar = self.view.apply_async(crash)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 eid = ar.engine_id
48 eid = ar.engine_id
49 tic = time.time()
49 tic = time.time()
50 while eid in self.client.ids and time.time()-tic < 5:
50 while eid in self.client.ids and time.time()-tic < 5:
51 time.sleep(.01)
51 time.sleep(.01)
52 self.client.spin()
52 self.client.spin()
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54
54
55 def test_map(self):
55 def test_map(self):
56 def f(x):
56 def f(x):
57 return x**2
57 return x**2
58 data = range(16)
58 data = range(16)
59 r = self.view.map_sync(f, data)
59 r = self.view.map_sync(f, data)
60 self.assertEqual(r, map(f, data))
60 self.assertEqual(r, map(f, data))
61
62 def test_map_generator(self):
63 def f(x):
64 return x**2
65
66 data = range(16)
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
69
70 def test_map_short_first(self):
71 def f(x,y):
72 if y is None:
73 return y
74 if x is None:
75 return x
76 return x*y
77 data = range(10)
78 data2 = range(4)
79
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
82
83 def test_map_short_last(self):
84 def f(x,y):
85 if y is None:
86 return y
87 if x is None:
88 return x
89 return x*y
90 data = range(4)
91 data2 = range(10)
92
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
61
95
62 def test_map_unordered(self):
96 def test_map_unordered(self):
63 def f(x):
97 def f(x):
64 return x**2
98 return x**2
65 def slow_f(x):
99 def slow_f(x):
66 import time
100 import time
67 time.sleep(0.05*x)
101 time.sleep(0.05*x)
68 return x**2
102 return x**2
69 data = range(16,0,-1)
103 data = range(16,0,-1)
70 reference = map(f, data)
104 reference = map(f, data)
71
105
72 amr = self.view.map_async(slow_f, data, ordered=False)
106 amr = self.view.map_async(slow_f, data, ordered=False)
73 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
74 # check individual elements, retrieved as they come
108 # check individual elements, retrieved as they come
75 # list comprehension uses __iter__
109 # list comprehension uses __iter__
76 astheycame = [ r for r in amr ]
110 astheycame = [ r for r in amr ]
77 # Ensure that at least one result came out of order:
111 # Ensure that at least one result came out of order:
78 self.assertNotEqual(astheycame, reference, "should not have preserved order")
112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
79 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
80
114
81 def test_map_ordered(self):
115 def test_map_ordered(self):
82 def f(x):
116 def f(x):
83 return x**2
117 return x**2
84 def slow_f(x):
118 def slow_f(x):
85 import time
119 import time
86 time.sleep(0.05*x)
120 time.sleep(0.05*x)
87 return x**2
121 return x**2
88 data = range(16,0,-1)
122 data = range(16,0,-1)
89 reference = map(f, data)
123 reference = map(f, data)
90
124
91 amr = self.view.map_async(slow_f, data)
125 amr = self.view.map_async(slow_f, data)
92 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
93 # check individual elements, retrieved as they come
127 # check individual elements, retrieved as they come
94 # list(amr) uses __iter__
128 # list(amr) uses __iter__
95 astheycame = list(amr)
129 astheycame = list(amr)
96 # Ensure that results came in order
130 # Ensure that results came in order
97 self.assertEqual(astheycame, reference)
131 self.assertEqual(astheycame, reference)
98 self.assertEqual(amr.result, reference)
132 self.assertEqual(amr.result, reference)
99
133
100 def test_map_iterable(self):
134 def test_map_iterable(self):
101 """test map on iterables (balanced)"""
135 """test map on iterables (balanced)"""
102 view = self.view
136 view = self.view
103 # 101 is prime, so it won't be evenly distributed
137 # 101 is prime, so it won't be evenly distributed
104 arr = range(101)
138 arr = range(101)
105 # so that it will be an iterator, even in Python 3
139 # so that it will be an iterator, even in Python 3
106 it = iter(arr)
140 it = iter(arr)
107 r = view.map_sync(lambda x:x, arr)
141 r = view.map_sync(lambda x:x, arr)
108 self.assertEqual(r, list(arr))
142 self.assertEqual(r, list(arr))
109
143
110
144
111 def test_abort(self):
145 def test_abort(self):
112 view = self.view
146 view = self.view
113 ar = self.client[:].apply_async(time.sleep, .5)
147 ar = self.client[:].apply_async(time.sleep, .5)
114 ar = self.client[:].apply_async(time.sleep, .5)
148 ar = self.client[:].apply_async(time.sleep, .5)
115 time.sleep(0.2)
149 time.sleep(0.2)
116 ar2 = view.apply_async(lambda : 2)
150 ar2 = view.apply_async(lambda : 2)
117 ar3 = view.apply_async(lambda : 3)
151 ar3 = view.apply_async(lambda : 3)
118 view.abort(ar2)
152 view.abort(ar2)
119 view.abort(ar3.msg_ids)
153 view.abort(ar3.msg_ids)
120 self.assertRaises(error.TaskAborted, ar2.get)
154 self.assertRaises(error.TaskAborted, ar2.get)
121 self.assertRaises(error.TaskAborted, ar3.get)
155 self.assertRaises(error.TaskAborted, ar3.get)
122
156
123 def test_retries(self):
157 def test_retries(self):
124 view = self.view
158 view = self.view
125 view.timeout = 1 # prevent hang if this doesn't behave
159 view.timeout = 1 # prevent hang if this doesn't behave
126 def fail():
160 def fail():
127 assert False
161 assert False
128 for r in range(len(self.client)-1):
162 for r in range(len(self.client)-1):
129 with view.temp_flags(retries=r):
163 with view.temp_flags(retries=r):
130 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
131
165
132 with view.temp_flags(retries=len(self.client), timeout=0.25):
166 with view.temp_flags(retries=len(self.client), timeout=0.25):
133 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
134
168
135 def test_invalid_dependency(self):
169 def test_invalid_dependency(self):
136 view = self.view
170 view = self.view
137 with view.temp_flags(after='12345'):
171 with view.temp_flags(after='12345'):
138 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
172 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
139
173
140 def test_impossible_dependency(self):
174 def test_impossible_dependency(self):
141 self.minimum_engines(2)
175 self.minimum_engines(2)
142 view = self.client.load_balanced_view()
176 view = self.client.load_balanced_view()
143 ar1 = view.apply_async(lambda : 1)
177 ar1 = view.apply_async(lambda : 1)
144 ar1.get()
178 ar1.get()
145 e1 = ar1.engine_id
179 e1 = ar1.engine_id
146 e2 = e1
180 e2 = e1
147 while e2 == e1:
181 while e2 == e1:
148 ar2 = view.apply_async(lambda : 1)
182 ar2 = view.apply_async(lambda : 1)
149 ar2.get()
183 ar2.get()
150 e2 = ar2.engine_id
184 e2 = ar2.engine_id
151
185
152 with view.temp_flags(follow=[ar1, ar2]):
186 with view.temp_flags(follow=[ar1, ar2]):
153 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
187 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
154
188
155
189
156 def test_follow(self):
190 def test_follow(self):
157 ar = self.view.apply_async(lambda : 1)
191 ar = self.view.apply_async(lambda : 1)
158 ar.get()
192 ar.get()
159 ars = []
193 ars = []
160 first_id = ar.engine_id
194 first_id = ar.engine_id
161
195
162 self.view.follow = ar
196 self.view.follow = ar
163 for i in range(5):
197 for i in range(5):
164 ars.append(self.view.apply_async(lambda : 1))
198 ars.append(self.view.apply_async(lambda : 1))
165 self.view.wait(ars)
199 self.view.wait(ars)
166 for ar in ars:
200 for ar in ars:
167 self.assertEqual(ar.engine_id, first_id)
201 self.assertEqual(ar.engine_id, first_id)
168
202
169 def test_after(self):
203 def test_after(self):
170 view = self.view
204 view = self.view
171 ar = view.apply_async(time.sleep, 0.5)
205 ar = view.apply_async(time.sleep, 0.5)
172 with view.temp_flags(after=ar):
206 with view.temp_flags(after=ar):
173 ar2 = view.apply_async(lambda : 1)
207 ar2 = view.apply_async(lambda : 1)
174
208
175 ar.wait()
209 ar.wait()
176 ar2.wait()
210 ar2.wait()
177 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
211 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
General Comments 0
You need to be logged in to leave comments. Login now