##// END OF EJS Templates
fix AsyncResult.abort...
MinRK -
Show More
@@ -1,396 +1,396 b''
1 1 """AsyncResult objects for the client
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import time
19 19
20 20 from zmq import MessageTracker
21 21
22 22 from IPython.external.decorator import decorator
23 23 from IPython.parallel import error
24 24
25 25 #-----------------------------------------------------------------------------
26 26 # Classes
27 27 #-----------------------------------------------------------------------------
28 28
29 29 # global empty tracker that's always done:
30 30 finished_tracker = MessageTracker()
31 31
@decorator
def check_ready(f, self, *args, **kwargs):
    """Run the wrapped method only if the result has already arrived.

    Polls once with ``self.wait(0)`` to refresh ``self._ready``; when the
    result is ready, `f` is called with the original arguments and its
    return value is passed through.  Otherwise ``error.TimeoutError`` is
    raised and `f` is never invoked.
    """
    # non-blocking poll so that _ready reflects the current state
    self.wait(0)
    if self._ready:
        return f(self, *args, **kwargs)
    raise error.TimeoutError("result not ready")
39 39
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.

    State kept on the instance:

    * ``msg_ids``   -- list of message ids this result is waiting on
    * ``_ready``    -- True once the client reported all msg_ids as done
    * ``_success``  -- None until ready, then True/False
    * ``_result`` / ``_exception`` -- only set once ready
    * ``_metadata`` -- list of per-message metadata, set once ready
    """

    msg_ids = None
    _targets = None
    _tracker = None
    _single_result = False

    def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
        """Wrap one or more pending messages.

        Parameters
        ----------
        client :
            object providing ``wait``, ``abort``, ``results`` and ``metadata``
        msg_ids : str or list of str
            a single id is wrapped into a one-element list
        fname : str
            name used in ``repr`` and passed to ``error.collect_exceptions``
        targets :
            forwarded to ``client.abort``; a list/tuple here also means the
            result is reported as a list even for a single message
        tracker :
            object with ``done`` and ``wait(timeout)``; defaults to the
            module-level ``finished_tracker``
        """
        if isinstance(msg_ids, basestring):
            # always a list
            msg_ids = [msg_ids]
        if tracker is None:
            # default to always done
            tracker = finished_tracker
        self._client = client
        self.msg_ids = msg_ids
        self._fname = fname
        self._targets = targets
        self._tracker = tracker
        self._ready = False
        self._success = None
        self._metadata = None
        if len(msg_ids) == 1:
            # one message is a "single result" unless the caller explicitly
            # addressed a list/tuple of targets
            self._single_result = not isinstance(targets, (list, tuple))
        else:
            self._single_result = False

    def __repr__(self):
        if self._ready:
            return "<%s: finished>"%(self.__class__.__name__)
        else:
            return "<%s: %s>"%(self.__class__.__name__,self._fname)

    def _reconstruct_result(self, res):
        """Reconstruct our result from actual result list (always a list)

        Override me in subclasses for turning a list of results
        into the expected form.
        """
        if self._single_result:
            return res[0]
        else:
            return res

    def get(self, timeout=-1):
        """Return the result when it arrives.

        `timeout` (in seconds) is handed to `wait()`, which forwards it to
        ``self._client.wait``.  If the result has not arrived once `wait()`
        returns, ``TimeoutError`` is raised.  If the remote call raised an
        exception then that exception is reraised by get().
        """
        if not self.ready():
            self.wait(timeout)

        if self._ready:
            if self._success:
                return self._result
            else:
                raise self._exception
        else:
            raise error.TimeoutError("Result not ready.")

    def ready(self):
        """Return whether the call has completed."""
        if not self._ready:
            self.wait(0)
        return self._ready

    def wait(self, timeout=-1):
        """Wait until the result is available or until `timeout` seconds pass.

        This method always returns None.
        """
        if self._ready:
            return
        self._ready = self._client.wait(self.msg_ids, timeout)
        if self._ready:
            try:
                results = [self._client.results.get(msg_id) for msg_id in self.msg_ids]
                # keep the raw list around in case an exception is raised below
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                self._result = self._reconstruct_result(results)
            except Exception as e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                # metadata is recorded whether or not the call succeeded
                self._metadata = [self._client.metadata.get(msg_id) for msg_id in self.msg_ids]

    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise ``AssertionError`` if the result is not ready.
        """
        assert self.ready()
        return self._success

    #----------------------------------------------------------------
    # Extra methods not in mp.pool.AsyncResult
    #----------------------------------------------------------------

    def get_dict(self, timeout=-1):
        """Get the results as a dict, keyed by engine_id.

        timeout behavior is described in `get()`.

        Raises ``ValueError`` if more than one job ran on the same engine,
        because the results could then not be keyed unambiguously.
        """

        results = self.get(timeout)
        if self._single_result:
            # get() returned the bare value; re-wrap it so it pairs up with
            # the one-element engine_ids list below
            results = [results]
        engine_ids = [ md['engine_id'] for md in self._metadata ]
        if engine_ids:
            # the last element after sorting by frequency is an engine with
            # the highest job count
            bycount = sorted(engine_ids, key=engine_ids.count)
            maxcount = bycount.count(bycount[-1])
            if maxcount > 1:
                raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
                        maxcount, bycount[-1]))

        return dict(zip(engine_ids,results))

    @property
    def result(self):
        """result property wrapper for `get(timeout=-1)`."""
        return self.get()

    # abbreviated alias:
    r = result

    @property
    @check_ready
    def metadata(self):
        """property for accessing execution metadata."""
        if self._single_result:
            return self._metadata[0]
        else:
            return self._metadata

    @property
    def result_dict(self):
        """result property as a dict."""
        return self.get_dict()

    # NOTE(review): defining `__dict__` as a method appears to shadow the
    # usual instance-dict attribute on this class; kept unchanged for
    # backward compatibility -- confirm whether anything relies on it.
    def __dict__(self):
        return self.get_dict(0)

    def abort(self):
        """abort my tasks.

        Raises ``AssertionError`` if the result is already done.  Otherwise
        returns whatever ``client.abort`` returns for a blocking abort of
        ``msg_ids`` on ``_targets``.
        """
        assert not self.ready(), "Can't abort, I am already done!"
        # must be `_client`: there is no `client` attribute, so that lookup
        # would fall through to __getattr__ and raise AttributeError
        return self._client.abort(self.msg_ids, targets=self._targets, block=True)

    @property
    def sent(self):
        """check whether my messages have been sent."""
        return self._tracker.done

    def wait_for_send(self, timeout=-1):
        """wait for pyzmq send to complete.

        This is necessary when sending arrays that you intend to edit in-place.
        `timeout` is in seconds, and will raise TimeoutError if it is reached
        before the send completes.
        """
        return self._tracker.wait(timeout)

    #-------------------------------------
    # dict-access
    #-------------------------------------

    @check_ready
    def __getitem__(self, key):
        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.

        Raises ``TimeoutError`` (via `check_ready`) if the result is not ready,
        and ``TypeError`` for any other key type.
        """
        if isinstance(key, int):
            return error.collect_exceptions([self._result[key]], self._fname)[0]
        elif isinstance(key, slice):
            return error.collect_exceptions(self._result[key], self._fname)
        elif isinstance(key, basestring):
            values = [ md[key] for md in self._metadata ]
            if self._single_result:
                return values[0]
            else:
                return values
        else:
            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))

    def __getattr__(self, key):
        """getattr maps to getitem for convenient attr access to metadata."""
        try:
            return self.__getitem__(key)
        except (error.TimeoutError, KeyError):
            # not ready yet, or no such metadata key: behave like a normal
            # missing attribute so hasattr() keeps working
            raise AttributeError("%r object has no attribute %r"%(
                    self.__class__.__name__, key))

    # asynchronous iterator:
    def __iter__(self):
        """Yield results one at a time, waiting on each message if needed.

        Raises ``TypeError`` for single-result objects.
        """
        if self._single_result:
            raise TypeError("AsyncResults with a single result are not iterable.")
        try:
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                yield ar.get()
        else:
            # already done
            for r in rlist:
                yield r
258 258
259 259
260 260
class AsyncMapResult(AsyncResult):
    """AsyncResult for non-blocking map calls; gathers the partial results.

    The final result is assembled by the map object's ``joinPartitions``.

    Instances can be iterated at any time; iteration waits for results as
    they come in.

    With ``ordered=True`` results are yielded in submission order; with
    ``ordered=False`` whichever results arrive first are yielded first.
    """

    def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        self.ordered = ordered
        self._mapObject = mapObject
        # a map is never reported as a single bare value
        self._single_result = False

    def _reconstruct_result(self, res):
        """Join the per-message partial results into the gathered result."""
        return self._mapObject.joinPartitions(res)

    # asynchronous iterator:
    def __iter__(self):
        if self.ordered:
            results = self._ordered_iter()
        else:
            results = self._unordered_iter()
        for item in results:
            yield item

    # asynchronous ordered iterator:
    def _ordered_iter(self):
        """iterator for results *as they arrive*, preserving submission order."""
        try:
            gathered = self.get(0)
        except error.TimeoutError:
            # not everything is in yet: block on each message in turn
            for msg_id in self.msg_ids:
                single = AsyncResult(self._client, msg_id, self._fname)
                chunk = single.get()
                try:
                    for item in chunk:
                        yield item
                except TypeError:
                    # the chunk was flattened to a non-iterable value;
                    # flattened data that is itself iterable would be
                    # mis-handled here, but most calls to map do not
                    # expose the `flatten` argument
                    yield chunk
        else:
            # everything had already arrived
            for item in gathered:
                yield item

    # asynchronous unordered iterator:
    def _unordered_iter(self):
        """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
        try:
            gathered = self.get(0)
        except error.TimeoutError:
            pending = set(self.msg_ids)
            while pending:
                try:
                    self._client.wait(pending, 1e-3)
                except error.TimeoutError:
                    # a timeout only means that *some* jobs are still
                    # outstanding, so it is safe to ignore
                    pass
                # whatever is no longer outstanding has finished
                ready = pending.difference(self._client.outstanding)
                # and no longer needs to be waited on
                pending = pending - ready
                while ready:
                    msg_id = ready.pop()
                    single = AsyncResult(self._client, msg_id, self._fname)
                    chunk = single.get()
                    try:
                        for item in chunk:
                            yield item
                    except TypeError:
                        # the chunk was flattened to a non-iterable value;
                        # flattened data that is itself iterable would be
                        # mis-handled here, but most calls to map do not
                        # expose the `flatten` argument
                        yield chunk
        else:
            # everything had already arrived
            for item in gathered:
                yield item
346 346
347 347
348 348
class AsyncHubResult(AsyncResult):
    """AsyncResult whose pending results have to be requested from the Hub.

    Waiting or polling on these objects means polling the Hub over the
    network, so use `AsyncHubResult.wait()` sparingly.
    """

    def wait(self, timeout=-1):
        """wait for result to complete.

        First waits on the messages the client itself still lists as
        outstanding, then polls ``client.result_status`` for any ids the
        client has no result for.  A negative `timeout` puts no time limit
        on the polling loop.
        """
        start = time.time()
        if self._ready:
            return

        def is_outstanding(msg_id):
            return msg_id in self._client.outstanding

        def is_remote(msg_id):
            return msg_id not in self._client.results

        local_ids = filter(is_outstanding, self.msg_ids)
        local_ready = self._client.wait(local_ids, timeout)
        if local_ready:
            remote_ids = filter(is_remote, self.msg_ids)
            if remote_ids:
                status = self._client.result_status(remote_ids, status_only=False)
                pending = status['pending']
                while pending and (timeout < 0 or time.time() < start+timeout):
                    status = self._client.result_status(remote_ids, status_only=False)
                    pending = status['pending']
                    if pending:
                        # still waiting: back off before asking the Hub again
                        time.sleep(0.1)
                if not pending:
                    self._ready = True
            else:
                # the client already holds every result
                self._ready = True

        if not self._ready:
            return

        try:
            results = [self._client.results.get(msg_id) for msg_id in self.msg_ids]
            # keep the raw list around in case an exception is raised below
            self._result = results
            if self._single_result:
                first = results[0]
                if isinstance(first, Exception):
                    raise first
            else:
                results = error.collect_exceptions(results, self._fname)
            self._result = self._reconstruct_result(results)
        except Exception as e:
            self._exception = e
            self._success = False
        else:
            self._success = True
        finally:
            # metadata is recorded whether or not the call succeeded
            self._metadata = [self._client.metadata.get(msg_id) for msg_id in self.msg_ids]
395 395
396 396 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] No newline at end of file
@@ -1,115 +1,125 b''
1 1 """Tests for asyncresult.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19
20 20 from IPython.parallel.error import TimeoutError
21 21
22 from IPython.parallel import error
22 23 from IPython.parallel.tests import add_engines
23 24 from .clienttest import ClusterTestCase
24 25
def setup():
    """Set up this test module by calling ``add_engines(2, total=True)``.

    NOTE(review): presumably picked up by the test runner as a module-level
    setup hook -- confirm against the test harness.
    """
    add_engines(2, total=True)
27 28
def wait(n):
    """Sleep for `n` seconds, then hand `n` back to the caller."""
    # the import stays inside the function body, as in the original;
    # presumably so the function is self-contained when it is sent to an
    # engine -- confirm before hoisting it to module level
    from time import sleep
    sleep(n)
    return n
32 33
class AsyncResultTest(ClusterTestCase):
    """Tests for the AsyncResult objects returned by the client's views."""

    def test_single_result_view(self):
        """various one-target views get the right value for single_result"""
        eid = self.client.ids[-1]
        # integer index -> bare value
        ar = self.client[eid].apply_async(lambda : 42)
        self.assertEqual(ar.get(), 42)
        # one-element list index -> one-element list
        ar = self.client[[eid]].apply_async(lambda : 42)
        self.assertEqual(ar.get(), [42])
        # slice index -> list
        ar = self.client[-1:].apply_async(lambda : 42)
        self.assertEqual(ar.get(), [42])

    def test_get_after_done(self):
        """get() can be called repeatedly once the result is in"""
        ar = self.client[-1].apply_async(lambda : 42)
        ar.wait()
        self.assertTrue(ar.ready())
        self.assertEqual(ar.get(), 42)
        self.assertEqual(ar.get(), 42)

    def test_get_before_done(self):
        """get(0) times out while the task is still running"""
        ar = self.client[-1].apply_async(wait, 0.1)
        self.assertRaises(TimeoutError, ar.get, 0)
        ar.wait(0)
        self.assertFalse(ar.ready())
        self.assertEqual(ar.get(), 0.1)

    def test_get_after_error(self):
        """a remote exception is re-raised on every get() / get_dict()"""
        ar = self.client[-1].apply_async(lambda : 1/0)
        ar.wait(10)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)
        self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)

    def test_get_dict(self):
        """get_dict() is keyed by engine id"""
        n = len(self.client)
        ar = self.client[:].apply_async(lambda : 5)
        self.assertEqual(ar.get(), [5]*n)
        d = ar.get_dict()
        self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
        for r in d.values():
            self.assertEqual(r, 5)

    def test_list_amr(self):
        """an AsyncMapResult can be turned into a list while tasks are running"""
        ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
        rlist = list(ar)
        # one result per submitted element
        self.assertEqual(len(rlist), 5)

    def test_getattr(self):
        """attribute access maps to metadata only once the result is ready"""
        ar = self.client[:].apply_async(wait, 0.5)
        # before the result arrives, nothing resolves
        self.assertRaises(AttributeError, lambda : ar._foo)
        self.assertRaises(AttributeError, lambda : ar.__length_hint__())
        self.assertRaises(AttributeError, lambda : ar.foo)
        self.assertRaises(AttributeError, lambda : ar.engine_id)
        self.assertFalse(hasattr(ar, '__length_hint__'))
        self.assertFalse(hasattr(ar, 'foo'))
        self.assertFalse(hasattr(ar, 'engine_id'))
        ar.get(5)
        # afterwards, only real metadata keys resolve
        self.assertRaises(AttributeError, lambda : ar._foo)
        self.assertRaises(AttributeError, lambda : ar.__length_hint__())
        self.assertRaises(AttributeError, lambda : ar.foo)
        self.assertTrue(isinstance(ar.engine_id, list))
        self.assertEqual(ar.engine_id, ar['engine_id'])
        self.assertFalse(hasattr(ar, '__length_hint__'))
        self.assertFalse(hasattr(ar, 'foo'))
        self.assertTrue(hasattr(ar, 'engine_id'))

    def test_getitem(self):
        """string keys raise TimeoutError before, KeyError/metadata after"""
        ar = self.client[:].apply_async(wait, 0.5)
        self.assertRaises(TimeoutError, lambda : ar['foo'])
        self.assertRaises(TimeoutError, lambda : ar['engine_id'])
        ar.get(5)
        self.assertRaises(KeyError, lambda : ar['foo'])
        self.assertTrue(isinstance(ar['engine_id'], list))
        self.assertEqual(ar.engine_id, ar['engine_id'])

    def test_single_result(self):
        """metadata of a single-target result is a bare value, not a list"""
        ar = self.client[-1].apply_async(wait, 0.5)
        self.assertRaises(TimeoutError, lambda : ar['foo'])
        self.assertRaises(TimeoutError, lambda : ar['engine_id'])
        self.assertTrue(ar.get(5) == 0.5)
        self.assertTrue(isinstance(ar['engine_id'], int))
        self.assertTrue(isinstance(ar.engine_id, int))
        self.assertEqual(ar.engine_id, ar['engine_id'])

    def test_abort(self):
        """a queued task can be aborted through its AsyncResult"""
        e = self.client[-1]
        # keep the engine busy so the second task is still pending
        ar = e.execute('import time; time.sleep(1)', block=False)
        ar2 = e.apply_async(lambda : 2)
        ar2.abort()
        self.assertRaises(error.TaskAborted, ar2.get)
        ar.get()
114 124
115 125
General Comments 0
You need to be logged in to leave comments. Login now