##// END OF EJS Templates
use finite wait_for_outputs
MinRK -
Show More
@@ -1,686 +1,686 b''
1 1 """AsyncResult objects for the client
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import sys
21 21 import time
22 22 from datetime import datetime
23 23
24 24 from zmq import MessageTracker
25 25
26 26 from IPython.core.display import clear_output, display, display_pretty
27 27 from IPython.external.decorator import decorator
28 28 from IPython.parallel import error
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Functions
32 32 #-----------------------------------------------------------------------------
33 33
34 34 def _total_seconds(td):
35 35 """timedelta.total_seconds was added in 2.7"""
36 36 try:
37 37 # Python >= 2.7
38 38 return td.total_seconds()
39 39 except AttributeError:
40 40 # Python 2.6
41 41 return 1e-6 * (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6)
42 42
def _raw_text(s):
    # display `s` as raw (unquoted) text via IPython's pretty-display hook
    display_pretty(s, raw=True)
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Classes
48 48 #-----------------------------------------------------------------------------
49 49
# global empty tracker that's always done:
finished_tracker = MessageTracker()

@decorator
def check_ready(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # zero timeout: sync client state without blocking
    self.wait(0)
    if not self._ready:
        raise error.TimeoutError("result not ready")
    return f(self, *args, **kwargs)
60 60
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
    """

    # class-level defaults; real values are assigned per-instance in __init__
    msg_ids = None           # list of message ids this result tracks
    _targets = None          # engine target(s) the request was sent to
    _tracker = None          # zmq MessageTracker for the outgoing send
    _single_result = False   # True when this wraps exactly one result value
72 72 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
73 73 if isinstance(msg_ids, basestring):
74 74 # always a list
75 75 msg_ids = [msg_ids]
76 76 if tracker is None:
77 77 # default to always done
78 78 tracker = finished_tracker
79 79 self._client = client
80 80 self.msg_ids = msg_ids
81 81 self._fname=fname
82 82 self._targets = targets
83 83 self._tracker = tracker
84 84 self._ready = False
85 85 self._success = None
86 86 self._metadata = None
87 87 if len(msg_ids) == 1:
88 88 self._single_result = not isinstance(targets, (list, tuple))
89 89 else:
90 90 self._single_result = False
91 91
92 92 def __repr__(self):
93 93 if self._ready:
94 94 return "<%s: finished>"%(self.__class__.__name__)
95 95 else:
96 96 return "<%s: %s>"%(self.__class__.__name__,self._fname)
97 97
98 98
99 99 def _reconstruct_result(self, res):
100 100 """Reconstruct our result from actual result list (always a list)
101 101
102 102 Override me in subclasses for turning a list of results
103 103 into the expected form.
104 104 """
105 105 if self._single_result:
106 106 return res[0]
107 107 else:
108 108 return res
109 109
110 110 def get(self, timeout=-1):
111 111 """Return the result when it arrives.
112 112
113 113 If `timeout` is not ``None`` and the result does not arrive within
114 114 `timeout` seconds then ``TimeoutError`` is raised. If the
115 115 remote call raised an exception then that exception will be reraised
116 116 by get() inside a `RemoteError`.
117 117 """
118 118 if not self.ready():
119 119 self.wait(timeout)
120 120
121 121 if self._ready:
122 122 if self._success:
123 123 return self._result
124 124 else:
125 125 raise self._exception
126 126 else:
127 127 raise error.TimeoutError("Result not ready.")
128 128
129 129 def ready(self):
130 130 """Return whether the call has completed."""
131 131 if not self._ready:
132 132 self.wait(0)
133 133 return self._ready
134 134
135 135 def wait(self, timeout=-1):
136 136 """Wait until the result is available or until `timeout` seconds pass.
137 137
138 138 This method always returns None.
139 139 """
140 140 if self._ready:
141 141 return
142 142 self._ready = self._client.wait(self.msg_ids, timeout)
143 143 if self._ready:
144 144 try:
145 145 results = map(self._client.results.get, self.msg_ids)
146 146 self._result = results
147 147 if self._single_result:
148 148 r = results[0]
149 149 if isinstance(r, Exception):
150 150 raise r
151 151 else:
152 152 results = error.collect_exceptions(results, self._fname)
153 153 self._result = self._reconstruct_result(results)
154 154 except Exception, e:
155 155 self._exception = e
156 156 self._success = False
157 157 else:
158 158 self._success = True
159 159 finally:
160 160 self._metadata = map(self._client.metadata.get, self.msg_ids)
161 self._wait_for_outputs()
161 self._wait_for_outputs(10)
162 162
163 163
164 164
165 165 def successful(self):
166 166 """Return whether the call completed without raising an exception.
167 167
168 168 Will raise ``AssertionError`` if the result is not ready.
169 169 """
170 170 assert self.ready()
171 171 return self._success
172 172
173 173 #----------------------------------------------------------------
174 174 # Extra methods not in mp.pool.AsyncResult
175 175 #----------------------------------------------------------------
176 176
177 177 def get_dict(self, timeout=-1):
178 178 """Get the results as a dict, keyed by engine_id.
179 179
180 180 timeout behavior is described in `get()`.
181 181 """
182 182
183 183 results = self.get(timeout)
184 184 engine_ids = [ md['engine_id'] for md in self._metadata ]
185 185 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
186 186 maxcount = bycount.count(bycount[-1])
187 187 if maxcount > 1:
188 188 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
189 189 maxcount, bycount[-1]))
190 190
191 191 return dict(zip(engine_ids,results))
192 192
    @property
    def result(self):
        """result property wrapper for `get(timeout=0)`."""
        # NOTE(review): despite the docstring, this calls get() with its
        # default timeout (-1, block until done), not get(0) — confirm intent
        return self.get()

    # abbreviated alias:
    r = result

    @property
    @check_ready
    def metadata(self):
        """property for accessing execution metadata."""
        # single results expose their metadata dict directly, not a 1-list
        if self._single_result:
            return self._metadata[0]
        else:
            return self._metadata

    @property
    def result_dict(self):
        """result property as a dict."""
        return self.get_dict()

    def __dict__(self):
        # NOTE(review): defining __dict__ as a method is unusual; it appears
        # intended as a dict-conversion hook rather than the attribute dict
        return self.get_dict(0)
217 217
    def abort(self):
        """abort my tasks."""
        # only tasks that have not completed can be aborted
        assert not self.ready(), "Can't abort, I am already done!"
        return self._client.abort(self.msg_ids, targets=self._targets, block=True)

    @property
    def sent(self):
        """check whether my messages have been sent."""
        # delegates to the zmq MessageTracker attached at construction
        return self._tracker.done

    def wait_for_send(self, timeout=-1):
        """wait for pyzmq send to complete.

        This is necessary when sending arrays that you intend to edit in-place.
        `timeout` is in seconds, and will raise TimeoutError if it is reached
        before the send completes.
        """
        return self._tracker.wait(timeout)
236 236
237 237 #-------------------------------------
238 238 # dict-access
239 239 #-------------------------------------
240 240
241 241 @check_ready
242 242 def __getitem__(self, key):
243 243 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
244 244 """
245 245 if isinstance(key, int):
246 246 return error.collect_exceptions([self._result[key]], self._fname)[0]
247 247 elif isinstance(key, slice):
248 248 return error.collect_exceptions(self._result[key], self._fname)
249 249 elif isinstance(key, basestring):
250 250 values = [ md[key] for md in self._metadata ]
251 251 if self._single_result:
252 252 return values[0]
253 253 else:
254 254 return values
255 255 else:
256 256 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
257 257
258 258 def __getattr__(self, key):
259 259 """getattr maps to getitem for convenient attr access to metadata."""
260 260 try:
261 261 return self.__getitem__(key)
262 262 except (error.TimeoutError, KeyError):
263 263 raise AttributeError("%r object has no attribute %r"%(
264 264 self.__class__.__name__, key))
265 265
    # asynchronous iterator:
    def __iter__(self):
        # a single result is a scalar from the caller's point of view
        if self._single_result:
            raise TypeError("AsyncResults with a single result are not iterable.")
        try:
            # non-blocking: succeeds only if everything has already arrived
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                yield ar.get()
        else:
            # already done
            for r in rlist:
                yield r

    def __len__(self):
        # one message id per task
        return len(self.msg_ids)
284 284
    #-------------------------------------
    # Sugar methods and attributes
    #-------------------------------------

    def timedelta(self, start, end, start_key=min, end_key=max):
        """compute the difference between two sets of timestamps

        The default behavior is to use the earliest of the first
        and the latest of the second list, but this can be changed
        by passing a different

        Parameters
        ----------

        start : one or more datetime objects (e.g. ar.submitted)
        end : one or more datetime objects (e.g. ar.received)
        start_key : callable
            Function to call on `start` to extract the relevant
            entry [default: min]
        end_key : callable
            Function to call on `end` to extract the relevant
            entry [default: max]

        Returns
        -------

        dt : float
            The time elapsed (in seconds) between the two selected timestamps.
        """
        if not isinstance(start, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            start = start_key(start)
        if not isinstance(end, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            end = end_key(end)
        return _total_seconds(end - start)
323 323
    @property
    def progress(self):
        """the number of tasks which have been completed at this point.

        Fractional progress would be given by 1.0 * ar.progress / len(ar)
        """
        self.wait(0)
        # tasks still in client.outstanding have not finished yet
        return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))

    @property
    def elapsed(self):
        """elapsed time since initial submission"""
        if self.ready():
            # finished: elapsed equals the total wall-clock run time
            return self.wall_time

        now = submitted = datetime.now()
        # find the earliest submission stamp among our messages
        for msg_id in self.msg_ids:
            if msg_id in self._client.metadata:
                stamp = self._client.metadata[msg_id]['submitted']
                if stamp and stamp < submitted:
                    submitted = stamp
        return _total_seconds(now-submitted)
346 346
347 347 @property
348 348 @check_ready
349 349 def serial_time(self):
350 350 """serial computation time of a parallel calculation
351 351
352 352 Computed as the sum of (completed-started) of each task
353 353 """
354 354 t = 0
355 355 for md in self._metadata:
356 356 t += _total_seconds(md['completed'] - md['started'])
357 357 return t
358 358
    @property
    @check_ready
    def wall_time(self):
        """actual computation time of a parallel calculation

        Computed as the time between the latest `received` stamp
        and the earliest `submitted`.

        Only reliable if Client was spinning/waiting when the task finished, because
        the `received` timestamp is created when a result is pulled off of the zmq queue,
        which happens as a result of `client.spin()`.

        For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.

        """
        return self.timedelta(self.submitted, self.received)

    def wait_interactive(self, interval=1., timeout=None):
        """interactive wait, printing progress at regular intervals"""
        N = len(self)
        tic = time.time()
        # poll until done or timed out, refreshing a one-line progress display
        while not self.ready() and (timeout is None or time.time() - tic <= timeout):
            self.wait(interval)
            clear_output()
            print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
            sys.stdout.flush()
        print()
        print("done")
387 387
    def _republish_displaypub(self, content, eid):
        """republish individual displaypub content dicts"""
        try:
            ip = get_ipython()
        except NameError:
            # displaypub is meaningless outside IPython
            return
        # tag the metadata with the originating engine id
        md = content['metadata'] or {}
        md['engine'] = eid
        ip.display_pub.publish(content['source'], content['data'], md)
398 398
399 399 def _display_stream(self, text, prefix='', file=None):
400 400 if not text:
401 401 # nothing to display
402 402 return
403 403 if file is None:
404 404 file = sys.stdout
405 405 end = '' if text.endswith('\n') else '\n'
406 406
407 407 multiline = text.count('\n') > int(text.endswith('\n'))
408 408 if prefix and multiline and not text.startswith('\n'):
409 409 prefix = prefix + '\n'
410 410 print("%s%s" % (prefix, text), file=file, end=end)
411 411
412 412
    def _display_single_result(self):
        # replay captured stdout/stderr locally
        self._display_stream(self.stdout)
        self._display_stream(self.stderr, file=sys.stderr)

        try:
            get_ipython()
        except NameError:
            # displaypub is meaningless outside IPython
            return

        # republish any rich display output, then the final pyout value
        for output in self.outputs:
            self._republish_displaypub(output, self.engine_id)

        if self.pyout is not None:
            display(self.get())
428 428
    def _wait_for_outputs(self, timeout=-1):
        """wait for the 'status=idle' message that indicates we have all outputs
        """
        if not self._success:
            # don't wait on errors
            return
        tic = time.time()
        # poll the iopub channel until every task reports its outputs complete
        while not all(md['outputs_ready'] for md in self._metadata):
            time.sleep(0.01)
            self._client._flush_iopub(self._client._iopub_socket)
            # a negative timeout means wait indefinitely
            if timeout >= 0 and time.time() > tic + timeout:
                break
441 441
    @check_ready
    def display_outputs(self, groupby="type"):
        """republish the outputs of the computation

        Parameters
        ----------

        groupby : str [default: type]
            if 'type':
                Group outputs by type (show all stdout, then all stderr, etc.):

                [stdout:1] foo
                [stdout:2] foo
                [stderr:1] bar
                [stderr:2] bar
            if 'engine':
                Display outputs for each engine before moving on to the next:

                [stdout:1] foo
                [stderr:1] bar
                [stdout:2] foo
                [stderr:2] bar

            if 'order':
                Like 'type', but further collate individual displaypub
                outputs. This is meant for cases of each command producing
                several plots, and you would like to see all of the first
                plots together, then all of the second plots, and so on.
        """
        if self._single_result:
            self._display_single_result()
            return

        # per-engine parallel lists, all indexed alike
        stdouts = self.stdout
        stderrs = self.stderr
        pyouts = self.pyout
        output_lists = self.outputs
        results = self.get()

        targets = self.engine_id

        if groupby == "engine":
            for eid,stdout,stderr,outputs,r,pyout in zip(
                    targets, stdouts, stderrs, output_lists, results, pyouts
                ):
                self._display_stream(stdout, '[stdout:%i] ' % eid)
                self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)

                try:
                    get_ipython()
                except NameError:
                    # displaypub is meaningless outside IPython
                    return

                if outputs or pyout is not None:
                    _raw_text('[output:%i]' % eid)

                for output in outputs:
                    self._republish_displaypub(output, eid)

                if pyout is not None:
                    display(r)

        elif groupby in ('type', 'order'):
            # republish stdout:
            for eid,stdout in zip(targets, stdouts):
                self._display_stream(stdout, '[stdout:%i] ' % eid)

            # republish stderr:
            for eid,stderr in zip(targets, stderrs):
                self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)

            try:
                get_ipython()
            except NameError:
                # displaypub is meaningless outside IPython
                return

            if groupby == 'order':
                output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
                N = max(len(outputs) for outputs in output_lists)
                for i in range(N):
                    for eid in targets:
                        outputs = output_dict[eid]
                        # NOTE(review): comparing against N (the max count) rather
                        # than i skips engines with fewer outputs than the max
                        # entirely — confirm this is intended
                        if len(outputs) >= N:
                            _raw_text('[output:%i]' % eid)
                            self._republish_displaypub(outputs[i], eid)
            else:
                # republish displaypub output
                for eid,outputs in zip(targets, output_lists):
                    if outputs:
                        _raw_text('[output:%i]' % eid)
                    for output in outputs:
                        self._republish_displaypub(output, eid)

            # finally, add pyout:
            for eid,r,pyout in zip(targets, results, pyouts):
                if pyout is not None:
                    display(r)

        else:
            # NOTE(review): message mentions 'collate' but the accepted value
            # is 'order' — the message and the check disagree
            raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
544 544
545 545
546 546
547 547
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    This will properly reconstruct the gather.

    This class is iterable at any time, and will wait on results as they come.

    If ordered=False, then the first results to arrive will come first, otherwise
    results will be yielded in the order they were submitted.

    """

    def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        # the map object knows how to stitch partitions back together
        self._mapObject = mapObject
        # a gather is always a collection, never a scalar result
        self._single_result = False
        self.ordered = ordered

    def _reconstruct_result(self, res):
        """Perform the gather on the actual results."""
        return self._mapObject.joinPartitions(res)
569 569
570 570 # asynchronous iterator:
571 571 def __iter__(self):
572 572 it = self._ordered_iter if self.ordered else self._unordered_iter
573 573 for r in it():
574 574 yield r
575 575
    # asynchronous ordered iterator:
    def _ordered_iter(self):
        """iterator for results *as they arrive*, preserving submission order."""
        try:
            # non-blocking: succeeds only if everything already arrived
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                rlist = ar.get()
                try:
                    for r in rlist:
                        yield r
                except TypeError:
                    # flattened, not a list
                    # this could get broken by flattened data that returns iterables
                    # but most calls to map do not expose the `flatten` argument
                    yield rlist
        else:
            # already done
            for r in rlist:
                yield r
598 598
    # asynchronous unordered iterator:
    def _unordered_iter(self):
        """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
        try:
            # non-blocking: succeeds only if everything already arrived
            rlist = self.get(0)
        except error.TimeoutError:
            pending = set(self.msg_ids)
            while pending:
                try:
                    self._client.wait(pending, 1e-3)
                except error.TimeoutError:
                    # ignore timeout error, because that only means
                    # *some* jobs are outstanding
                    pass
                # update ready set with those no longer outstanding:
                ready = pending.difference(self._client.outstanding)
                # update pending to exclude those that are finished
                pending = pending.difference(ready)
                # yield every finished task before polling again
                while ready:
                    msg_id = ready.pop()
                    ar = AsyncResult(self._client, msg_id, self._fname)
                    rlist = ar.get()
                    try:
                        for r in rlist:
                            yield r
                    except TypeError:
                        # flattened, not a list
                        # this could get broken by flattened data that returns iterables
                        # but most calls to map do not expose the `flatten` argument
                        yield rlist
        else:
            # already done
            for r in rlist:
                yield r
633 633
634 634
class AsyncHubResult(AsyncResult):
    """Class to wrap pending results that must be requested from the Hub.

    Note that waiting/polling on these objects requires polling the Hub over the network,
    so use `AsyncHubResult.wait()` sparingly.
    """

    def _wait_for_outputs(self, timeout=None):
        """no-op, because HubResults are never incomplete"""
        return
645 645
    def wait(self, timeout=-1):
        """wait for result to complete."""
        start = time.time()
        if self._ready:
            return
        # first, wait on any messages the local client is still tracking
        local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
        local_ready = self._client.wait(local_ids, timeout)
        if local_ready:
            # anything missing from local results must be fetched from the Hub
            remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
            if not remote_ids:
                self._ready = True
            else:
                rdict = self._client.result_status(remote_ids, status_only=False)
                pending = rdict['pending']
                # poll the Hub until nothing is pending or we hit the timeout
                # (negative timeout means poll forever)
                while pending and (timeout < 0 or time.time() < start+timeout):
                    rdict = self._client.result_status(remote_ids, status_only=False)
                    pending = rdict['pending']
                    if pending:
                        time.sleep(0.1)
                if not pending:
                    self._ready = True
        if self._ready:
            # mirror AsyncResult.wait: collect results and record success/failure
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                    self._result = self._reconstruct_result(results)
            except Exception, e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                self._metadata = map(self._client.metadata.get, self.msg_ids)
685 685
# public API of this module
__all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
General Comments 0
You need to be logged in to leave comments. Login now