##// END OF EJS Templates
Merge pull request #1608 from minrk/ar_sugar_2.6...
Fernando Perez -
r6505:5750e2dd merge
parent child Browse files
Show More
@@ -1,505 +1,517 b''
1 1 """AsyncResult objects for the client
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 from zmq import MessageTracker
23 23
24 24 from IPython.core.display import clear_output
25 25 from IPython.external.decorator import decorator
26 26 from IPython.parallel import error
27 27
28 #-----------------------------------------------------------------------------
29 # Functions
30 #-----------------------------------------------------------------------------
31
32 def _total_seconds(td):
33 """timedelta.total_seconds was added in 2.7"""
34 try:
35 # Python >= 2.7
36 return td.total_seconds()
37 except AttributeError:
38 # Python 2.6
39 return 1e-6 * (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6)
28 40
29 41 #-----------------------------------------------------------------------------
30 42 # Classes
31 43 #-----------------------------------------------------------------------------
32 44
# global empty tracker that's always done:
# used as the default `tracker` in AsyncResult.__init__ when the caller
# did not request send-tracking, so tracker-based waits return immediately
finished_tracker = MessageTracker()
35 47
@decorator
def check_ready(f, self, *args, **kwargs):
    """Sync state via a zero-timeout wait, then call the wrapped method.

    Raises TimeoutError instead of calling `f` if the result is not ready.
    """
    self.wait(0)
    if self._ready:
        return f(self, *args, **kwargs)
    raise error.TimeoutError("result not ready")
43 55
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
    """

    # class-level defaults, overridden per-instance in __init__:
    msg_ids = None           # list of msg_id strings this object tracks
    _targets = None          # engine target(s) the request went to, if known
    _tracker = None          # zmq MessageTracker for the outgoing send
    _single_result = False   # whether to unwrap the one-element result list

    def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
        """Create an AsyncResult.

        Parameters
        ----------
        client : Client
            The parallel Client whose `wait`, `results`, and `metadata`
            are consulted to resolve this result.
        msg_ids : str or list of str
            The message id(s) whose results this object represents.
            A bare string is normalized to a one-element list.
        fname : str
            Name of the remote call, used in repr and error collection.
        targets : int or list/tuple, optional
            The engine target(s) of the request, if known.
        tracker : MessageTracker, optional
            pyzmq send tracker; defaults to an already-finished tracker.
        """
        if isinstance(msg_ids, basestring):
            # always a list
            msg_ids = [msg_ids]
        if tracker is None:
            # default to always done
            tracker = finished_tracker
        self._client = client
        self.msg_ids = msg_ids
        self._fname=fname
        self._targets = targets
        self._tracker = tracker
        self._ready = False
        self._success = None
        self._metadata = None
        if len(msg_ids) == 1:
            # one msg_id with a scalar (non-list) target means the caller
            # expects a bare value rather than a one-element list
            self._single_result = not isinstance(targets, (list, tuple))
        else:
            self._single_result = False

    def __repr__(self):
        if self._ready:
            return "<%s: finished>"%(self.__class__.__name__)
        else:
            return "<%s: %s>"%(self.__class__.__name__,self._fname)


    def _reconstruct_result(self, res):
        """Reconstruct our result from actual result list (always a list)

        Override me in subclasses for turning a list of results
        into the expected form.
        """
        if self._single_result:
            return res[0]
        else:
            return res

    def get(self, timeout=-1):
        """Return the result when it arrives.

        If `timeout` is not ``None`` and the result does not arrive within
        `timeout` seconds then ``TimeoutError`` is raised. If the
        remote call raised an exception then that exception will be reraised
        by get() inside a `RemoteError`.
        """
        if not self.ready():
            self.wait(timeout)

        if self._ready:
            if self._success:
                return self._result
            else:
                # re-raise the exception captured in wait()
                raise self._exception
        else:
            raise error.TimeoutError("Result not ready.")

    def ready(self):
        """Return whether the call has completed."""
        if not self._ready:
            # zero-timeout wait syncs state without blocking
            self.wait(0)
        return self._ready

    def wait(self, timeout=-1):
        """Wait until the result is available or until `timeout` seconds pass.

        This method always returns None.
        """
        if self._ready:
            return
        self._ready = self._client.wait(self.msg_ids, timeout)
        if self._ready:
            # results just arrived: cache result, success flag, and metadata
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        # a single remote exception is raised directly
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                    self._result = self._reconstruct_result(results)
            except Exception, e:
                # remember the failure so get() can re-raise it later
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                # metadata is recorded whether the call succeeded or not
                self._metadata = map(self._client.metadata.get, self.msg_ids)


    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise ``AssertionError`` if the result is not ready.
        """
        assert self.ready()
        return self._success

    #----------------------------------------------------------------
    # Extra methods not in mp.pool.AsyncResult
    #----------------------------------------------------------------

    def get_dict(self, timeout=-1):
        """Get the results as a dict, keyed by engine_id.

        timeout behavior is described in `get()`.

        Raises ValueError if any engine ran more than one of the tasks,
        because the engine_id keys would then collide.
        """

        results = self.get(timeout)
        engine_ids = [ md['engine_id'] for md in self._metadata ]
        # sort ids by how often they occur, so the most frequent is last
        bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
        maxcount = bycount.count(bycount[-1])
        if maxcount > 1:
            raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
                    maxcount, bycount[-1]))

        return dict(zip(engine_ids,results))

    @property
    def result(self):
        """result property wrapper for `get()` (blocks until ready)."""
        return self.get()

    # abbreviated alias:
    r = result

    @property
    @check_ready
    def metadata(self):
        """property for accessing execution metadata."""
        if self._single_result:
            return self._metadata[0]
        else:
            return self._metadata

    @property
    def result_dict(self):
        """result property as a dict."""
        return self.get_dict()

    def __dict__(self):
        # NOTE(review): defining `__dict__` as a method is unusual — it
        # shadows the default instance-dict descriptor (e.g. breaks
        # vars(ar)); confirm this is intentional.
        return self.get_dict(0)

    def abort(self):
        """abort my tasks."""
        assert not self.ready(), "Can't abort, I am already done!"
        return self._client.abort(self.msg_ids, targets=self._targets, block=True)

    @property
    def sent(self):
        """check whether my messages have been sent."""
        return self._tracker.done

    def wait_for_send(self, timeout=-1):
        """wait for pyzmq send to complete.

        This is necessary when sending arrays that you intend to edit in-place.
        `timeout` is in seconds, and will raise TimeoutError if it is reached
        before the send completes.
        """
        return self._tracker.wait(timeout)

    #-------------------------------------
    # dict-access
    #-------------------------------------

    @check_ready
    def __getitem__(self, key):
        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
        """
        if isinstance(key, int):
            # single result; re-raise if it was a remote exception
            return error.collect_exceptions([self._result[key]], self._fname)[0]
        elif isinstance(key, slice):
            return error.collect_exceptions(self._result[key], self._fname)
        elif isinstance(key, basestring):
            # string key: look up in per-task metadata, not results
            values = [ md[key] for md in self._metadata ]
            if self._single_result:
                return values[0]
            else:
                return values
        else:
            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))

    def __getattr__(self, key):
        """getattr maps to getitem for convenient attr access to metadata."""
        try:
            return self.__getitem__(key)
        except (error.TimeoutError, KeyError):
            raise AttributeError("%r object has no attribute %r"%(
                    self.__class__.__name__, key))

    # asynchronous iterator:
    def __iter__(self):
        if self._single_result:
            raise TypeError("AsyncResults with a single result are not iterable.")
        try:
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                yield ar.get()
        else:
            # already done
            for r in rlist:
                yield r

    def __len__(self):
        return len(self.msg_ids)

    #-------------------------------------
    # Sugar methods and attributes
    #-------------------------------------

    def timedelta(self, start, end, start_key=min, end_key=max):
        """compute the difference between two sets of timestamps

        The default behavior is to use the earliest of the first
        and the latest of the second list, but this can be changed
        by passing a different `start_key` or `end_key` function.

        Parameters
        ----------

        start : one or more datetime objects (e.g. ar.submitted)
        end : one or more datetime objects (e.g. ar.received)
        start_key : callable
            Function to call on `start` to extract the relevant
            entry [default: min]
        end_key : callable
            Function to call on `end` to extract the relevant
            entry [default: max]

        Returns
        -------

        dt : float
            The time elapsed (in seconds) between the two selected timestamps.
        """
        if not isinstance(start, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            start = start_key(start)
        if not isinstance(end, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            end = end_key(end)
        # _total_seconds rather than timedelta.total_seconds, for Python 2.6
        return _total_seconds(end - start)

    @property
    def progress(self):
        """the number of tasks which have been completed at this point.

        Fractional progress would be given by 1.0 * ar.progress / len(ar)
        """
        self.wait(0)
        # completed = total - still-outstanding
        return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))

    @property
    def elapsed(self):
        """elapsed time since initial submission"""
        if self.ready():
            return self.wall_time

        # find the earliest 'submitted' stamp among our messages
        now = submitted = datetime.now()
        for msg_id in self.msg_ids:
            if msg_id in self._client.metadata:
                stamp = self._client.metadata[msg_id]['submitted']
                if stamp and stamp < submitted:
                    submitted = stamp
        return _total_seconds(now-submitted)

    @property
    @check_ready
    def serial_time(self):
        """serial computation time of a parallel calculation

        Computed as the sum of (completed-started) of each task
        """
        t = 0
        for md in self._metadata:
            t += _total_seconds(md['completed'] - md['started'])
        return t

    @property
    @check_ready
    def wall_time(self):
        """actual computation time of a parallel calculation

        Computed as the time between the latest `received` stamp
        and the earliest `submitted`.

        Only reliable if Client was spinning/waiting when the task finished, because
        the `received` timestamp is created when a result is pulled off of the zmq queue,
        which happens as a result of `client.spin()`.

        For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.

        """
        return self.timedelta(self.submitted, self.received)

    def wait_interactive(self, interval=1., timeout=None):
        """interactive wait, printing progress at regular intervals"""
        N = len(self)
        tic = time.time()
        while not self.ready() and (timeout is None or time.time() - tic <= timeout):
            self.wait(interval)
            clear_output()
            print "%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed),
            sys.stdout.flush()
        print
        print "done"
368 380
369 381
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    This will properly reconstruct the gather.

    This class is iterable at any time, and will wait on results as they come.

    If ordered=False, then the first results to arrive will come first, otherwise
    results will be yielded in the order they were submitted.

    """

    def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
        """Same as AsyncResult, plus the map object used to rejoin
        partitioned results, and the `ordered` iteration flag."""
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        self._mapObject = mapObject
        # a map result is always a gathered list, never a bare value
        self._single_result = False
        self.ordered = ordered

    def _reconstruct_result(self, res):
        """Perform the gather on the actual results."""
        return self._mapObject.joinPartitions(res)

    # asynchronous iterator:
    def __iter__(self):
        # dispatch on the ordered flag chosen at construction time
        it = self._ordered_iter if self.ordered else self._unordered_iter
        for r in it():
            yield r

    # asynchronous ordered iterator:
    def _ordered_iter(self):
        """iterator for results *as they arrive*, preserving submission order."""
        try:
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                rlist = ar.get()
                try:
                    for r in rlist:
                        yield r
                except TypeError:
                    # flattened, not a list
                    # this could get broken by flattened data that returns iterables
                    # but most calls to map do not expose the `flatten` argument
                    yield rlist
        else:
            # already done
            for r in rlist:
                yield r

    # asynchronous unordered iterator:
    def _unordered_iter(self):
        """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
        try:
            rlist = self.get(0)
        except error.TimeoutError:
            # poll until every pending msg_id has finished, yielding each
            # result as soon as it is no longer outstanding
            pending = set(self.msg_ids)
            while pending:
                try:
                    self._client.wait(pending, 1e-3)
                except error.TimeoutError:
                    # ignore timeout error, because that only means
                    # *some* jobs are outstanding
                    pass
                # update ready set with those no longer outstanding:
                ready = pending.difference(self._client.outstanding)
                # update pending to exclude those that are finished
                pending = pending.difference(ready)
                while ready:
                    msg_id = ready.pop()
                    ar = AsyncResult(self._client, msg_id, self._fname)
                    rlist = ar.get()
                    try:
                        for r in rlist:
                            yield r
                    except TypeError:
                        # flattened, not a list
                        # this could get broken by flattened data that returns iterables
                        # but most calls to map do not expose the `flatten` argument
                        yield rlist
        else:
            # already done
            for r in rlist:
                yield r
455 467
456 468
457 469
class AsyncHubResult(AsyncResult):
    """Class to wrap pending results that must be requested from the Hub.

    Note that waiting/polling on these objects requires polling the Hub over the network,
    so use `AsyncHubResult.wait()` sparingly.
    """

    def wait(self, timeout=-1):
        """wait for result to complete.

        First waits locally for any messages still outstanding with this
        client, then polls the Hub (at 0.1s intervals) for results that
        were delivered to a different client, until done or `timeout`
        (in seconds; negative means wait forever) expires.
        """
        start = time.time()
        if self._ready:
            return
        # split our msg_ids: those still in flight locally...
        local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
        local_ready = self._client.wait(local_ids, timeout)
        if local_ready:
            # ...and those whose results must be fetched from the Hub
            remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
            if not remote_ids:
                self._ready = True
            else:
                rdict = self._client.result_status(remote_ids, status_only=False)
                pending = rdict['pending']
                # poll the Hub until nothing is pending or we run out of time
                while pending and (timeout < 0 or time.time() < start+timeout):
                    rdict = self._client.result_status(remote_ids, status_only=False)
                    pending = rdict['pending']
                    if pending:
                        time.sleep(0.1)
                if not pending:
                    self._ready = True
        if self._ready:
            # same result/metadata caching as AsyncResult.wait
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                    self._result = self._reconstruct_result(results)
            except Exception, e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                self._metadata = map(self._client.metadata.get, self.msg_ids)
504 516
# public API of this module
__all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
General Comments 0
You need to be logged in to leave comments. Login now