##// END OF EJS Templates
fix AsyncResult.get_dict for single result
Author: MinRK
Show More
@@ -1,708 +1,713 b''
1 1 """AsyncResult objects for the client
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import sys
21 21 import time
22 22 from datetime import datetime
23 23
24 24 from zmq import MessageTracker
25 25
26 26 from IPython.core.display import clear_output, display, display_pretty
27 27 from IPython.external.decorator import decorator
28 28 from IPython.parallel import error
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Functions
32 32 #-----------------------------------------------------------------------------
33 33
34 34 def _total_seconds(td):
35 35 """timedelta.total_seconds was added in 2.7"""
36 36 try:
37 37 # Python >= 2.7
38 38 return td.total_seconds()
39 39 except AttributeError:
40 40 # Python 2.6
41 41 return 1e-6 * (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6)
42 42
def _raw_text(s):
    # Display *s* as plain pretty-printed text (raw=True skips repr quoting).
    display_pretty(s, raw=True)
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Classes
48 48 #-----------------------------------------------------------------------------
49 49
# global empty tracker that's always done:
finished_tracker = MessageTracker()

@decorator
def check_ready(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # non-blocking sync with the client; raises if results have not arrived
    self.wait(0)
    if not self._ready:
        raise error.TimeoutError("result not ready")
    return f(self, *args, **kwargs)
60 60
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
    """

    # class-level defaults so attribute access is safe before __init__ runs
    msg_ids = None
    _targets = None
    _tracker = None
    _single_result = False

    def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
        if isinstance(msg_ids, basestring):
            # always a list
            msg_ids = [msg_ids]
        if tracker is None:
            # default to always done
            tracker = finished_tracker
        self._client = client
        self.msg_ids = msg_ids
        self._fname=fname
        self._targets = targets
        self._tracker = tracker
        self._ready = False
        self._outputs_ready = False
        self._success = None
        # snapshot of per-message metadata dicts from the client
        self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ]
        if len(msg_ids) == 1:
            # one msg_id with a scalar (non list/tuple) target is a "single result":
            # get() will unwrap the one-element list
            self._single_result = not isinstance(targets, (list, tuple))
        else:
            self._single_result = False
93 93 def __repr__(self):
94 94 if self._ready:
95 95 return "<%s: finished>"%(self.__class__.__name__)
96 96 else:
97 97 return "<%s: %s>"%(self.__class__.__name__,self._fname)
98 98
99 99
100 100 def _reconstruct_result(self, res):
101 101 """Reconstruct our result from actual result list (always a list)
102 102
103 103 Override me in subclasses for turning a list of results
104 104 into the expected form.
105 105 """
106 106 if self._single_result:
107 107 return res[0]
108 108 else:
109 109 return res
110 110
    def get(self, timeout=-1):
        """Return the result when it arrives.

        If `timeout` is not ``None`` and the result does not arrive within
        `timeout` seconds then ``TimeoutError`` is raised. If the
        remote call raised an exception then that exception will be reraised
        by get() inside a `RemoteError`.
        """
        if not self.ready():
            self.wait(timeout)

        if self._ready:
            if self._success:
                return self._result
            else:
                # re-raise the exception captured by wait()
                raise self._exception
        else:
            raise error.TimeoutError("Result not ready.")
129 129
130 130 def _check_ready(self):
131 131 if not self.ready():
132 132 raise error.TimeoutError("Result not ready.")
133 133
    def ready(self):
        """Return whether the call has completed."""
        if not self._ready:
            # non-blocking poll for results
            self.wait(0)
        elif not self._outputs_ready:
            # results are in, but stdout/stderr/display output may still be in flight
            self._wait_for_outputs(0)

        return self._ready
142 142
    def wait(self, timeout=-1):
        """Wait until the result is available or until `timeout` seconds pass.

        This method always returns None.
        """
        if self._ready:
            # results already collected; may still need to wait for output streams
            self._wait_for_outputs(timeout)
            return
        self._ready = self._client.wait(self.msg_ids, timeout)
        if self._ready:
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        # single failed result: surface the remote exception directly
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                    self._result = self._reconstruct_result(results)
            except Exception as e:
                # remember the failure so get() can re-raise it
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                if timeout is None or timeout < 0:
                    # cutoff infinite wait at 10s
                    timeout = 10
                self._wait_for_outputs(timeout)
173 173
174 174
    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise ``AssertionError`` if the result is not ready.
        """
        # NOTE(review): assert is stripped under `python -O`; mirrors
        # multiprocessing.pool.AsyncResult's documented contract, so kept as-is.
        assert self.ready()
        return self._success
182 182
183 183 #----------------------------------------------------------------
184 184 # Extra methods not in mp.pool.AsyncResult
185 185 #----------------------------------------------------------------
186 186
187 187 def get_dict(self, timeout=-1):
188 188 """Get the results as a dict, keyed by engine_id.
189 189
190 190 timeout behavior is described in `get()`.
191 191 """
192 192
193 193 results = self.get(timeout)
194 if self._single_result:
195 results = [results]
194 196 engine_ids = [ md['engine_id'] for md in self._metadata ]
195 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
196 maxcount = bycount.count(bycount[-1])
197 if maxcount > 1:
197
198 seen = set()
199 for engine_id in engine_ids:
200 if engine_id in seen:
198 201 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
199 maxcount, bycount[-1]))
202 engine_ids.count(engine_id), engine_id))
203 else:
204 seen.add(engine_id)
200 205
201 206 return dict(zip(engine_ids,results))
202 207
    @property
    def result(self):
        """result property wrapper for `get(timeout=-1)`."""
        return self.get()

    # abbreviated alias:
    r = result
210 215
211 216 @property
212 217 def metadata(self):
213 218 """property for accessing execution metadata."""
214 219 if self._single_result:
215 220 return self._metadata[0]
216 221 else:
217 222 return self._metadata
218 223
    @property
    def result_dict(self):
        """result property as a dict."""
        # blocking: delegates to get_dict() with the default timeout
        return self.get_dict()
223 228
    # NOTE(review): defining __dict__ as a method shadows the instance
    # attribute-dict protocol and is unusual — verify this is intentional.
    def __dict__(self):
        return self.get_dict(0)
226 231
    def abort(self):
        """abort my tasks."""
        # NOTE(review): assert is stripped under `python -O`
        assert not self.ready(), "Can't abort, I am already done!"
        return self._client.abort(self.msg_ids, targets=self._targets, block=True)
231 236
    @property
    def sent(self):
        """check whether my messages have been sent."""
        # delegates to the zmq MessageTracker attached at construction
        return self._tracker.done
236 241
    def wait_for_send(self, timeout=-1):
        """wait for pyzmq send to complete.

        This is necessary when sending arrays that you intend to edit in-place.
        `timeout` is in seconds, and will raise TimeoutError if it is reached
        before the send completes.
        """
        return self._tracker.wait(timeout)
245 250
246 251 #-------------------------------------
247 252 # dict-access
248 253 #-------------------------------------
249 254
    def __getitem__(self, key):
        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
        """
        if isinstance(key, int):
            # single result by position; requires results to have arrived
            self._check_ready()
            return error.collect_exceptions([self._result[key]], self._fname)[0]
        elif isinstance(key, slice):
            self._check_ready()
            return error.collect_exceptions(self._result[key], self._fname)
        elif isinstance(key, basestring):
            # metadata proxy *does not* require that results are done
            self.wait(0)
            values = [ md[key] for md in self._metadata ]
            if self._single_result:
                return values[0]
            else:
                return values
        else:
            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
269 274
    def __getattr__(self, key):
        """getattr maps to getitem for convenient attr access to metadata."""
        try:
            return self.__getitem__(key)
        except (error.TimeoutError, KeyError):
            # present missing metadata as a normal AttributeError
            raise AttributeError("%r object has no attribute %r"%(
                self.__class__.__name__, key))
277 282
    # asynchronous iterator:
    def __iter__(self):
        """Yield results one at a time, blocking on each as it arrives."""
        if self._single_result:
            raise TypeError("AsyncResults with a single result are not iterable.")
        try:
            # non-blocking probe: succeeds only if everything is already done
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                yield ar.get()
        else:
            # already done
            for r in rlist:
                yield r
293 298
    def __len__(self):
        # number of tasks (message ids) tracked by this AsyncResult
        return len(self.msg_ids)
296 301
297 302 #-------------------------------------
298 303 # Sugar methods and attributes
299 304 #-------------------------------------
300 305
    def timedelta(self, start, end, start_key=min, end_key=max):
        """compute the difference between two sets of timestamps

        The default behavior is to use the earliest of the first
        and the latest of the second list, but this can be changed
        by passing different extraction functions as `start_key`/`end_key`.

        Parameters
        ----------

        start : one or more datetime objects (e.g. ar.submitted)
        end : one or more datetime objects (e.g. ar.received)
        start_key : callable
            Function to call on `start` to extract the relevant
            entry [default: min]
        end_key : callable
            Function to call on `end` to extract the relevant
            entry [default: max]

        Returns
        -------

        dt : float
            The time elapsed (in seconds) between the two selected timestamps.
        """
        if not isinstance(start, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            start = start_key(start)
        if not isinstance(end, datetime):
            # handle single_result AsyncResults, where ar.stamp is single object,
            # not a list
            end = end_key(end)
        return _total_seconds(end - start)
335 340
    @property
    def progress(self):
        """the number of tasks which have been completed at this point.

        Fractional progress would be given by 1.0 * ar.progress / len(ar)
        """
        self.wait(0)
        # completed = total minus those still outstanding at the client
        return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
344 349
    @property
    def elapsed(self):
        """elapsed time since initial submission"""
        if self.ready():
            return self.wall_time

        # still running: measure from the earliest known submission stamp
        now = submitted = datetime.now()
        for msg_id in self.msg_ids:
            if msg_id in self._client.metadata:
                stamp = self._client.metadata[msg_id]['submitted']
                if stamp and stamp < submitted:
                    submitted = stamp
        return _total_seconds(now-submitted)
358 363
    @property
    @check_ready
    def serial_time(self):
        """serial computation time of a parallel calculation

        Computed as the sum of (completed-started) of each task
        """
        t = 0
        for md in self._metadata:
            t += _total_seconds(md['completed'] - md['started'])
        return t
370 375
    @property
    @check_ready
    def wall_time(self):
        """actual computation time of a parallel calculation

        Computed as the time between the latest `received` stamp
        and the earliest `submitted`.

        Only reliable if Client was spinning/waiting when the task finished, because
        the `received` timestamp is created when a result is pulled off of the zmq queue,
        which happens as a result of `client.spin()`.

        For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.

        """
        return self.timedelta(self.submitted, self.received)
387 392
    def wait_interactive(self, interval=1., timeout=-1):
        """interactive wait, printing progress at regular intervals"""
        if timeout is None:
            # normalize None to -1 (wait forever)
            timeout = -1
        N = len(self)
        tic = time.time()
        while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
            self.wait(interval)
            clear_output()
            # overwrite the progress line in place (end="")
            print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
            sys.stdout.flush()
        print()
        print("done")
401 406
    def _republish_displaypub(self, content, eid):
        """republish individual displaypub content dicts"""
        try:
            ip = get_ipython()
        except NameError:
            # displaypub is meaningless outside IPython
            return
        # tag republished output with the originating engine id
        md = content['metadata'] or {}
        md['engine'] = eid
        ip.display_pub.publish(content['source'], content['data'], md)
412 417
413 418 def _display_stream(self, text, prefix='', file=None):
414 419 if not text:
415 420 # nothing to display
416 421 return
417 422 if file is None:
418 423 file = sys.stdout
419 424 end = '' if text.endswith('\n') else '\n'
420 425
421 426 multiline = text.count('\n') > int(text.endswith('\n'))
422 427 if prefix and multiline and not text.startswith('\n'):
423 428 prefix = prefix + '\n'
424 429 print("%s%s" % (prefix, text), file=file, end=end)
425 430
426 431
    def _display_single_result(self):
        """Display stdout/stderr/display outputs of a single-result task."""
        self._display_stream(self.stdout)
        self._display_stream(self.stderr, file=sys.stderr)

        try:
            get_ipython()
        except NameError:
            # displaypub is meaningless outside IPython
            return

        for output in self.outputs:
            self._republish_displaypub(output, self.engine_id)

        if self.pyout is not None:
            display(self.get())
442 447
    def _wait_for_outputs(self, timeout=-1):
        """wait for the 'status=idle' message that indicates we have all outputs
        """
        if self._outputs_ready or not self._success:
            # don't wait on errors
            return

        # cast None to -1 for infinite timeout
        if timeout is None:
            timeout = -1

        tic = time.time()
        while True:
            # pump the iopub channel so output metadata gets updated
            self._client._flush_iopub(self._client._iopub_socket)
            self._outputs_ready = all(md['outputs_ready']
                                      for md in self._metadata)
            if self._outputs_ready or \
               (timeout >= 0 and time.time() > tic + timeout):
                break
            time.sleep(0.01)
463 468
464 469 @check_ready
465 470 def display_outputs(self, groupby="type"):
466 471 """republish the outputs of the computation
467 472
468 473 Parameters
469 474 ----------
470 475
471 476 groupby : str [default: type]
472 477 if 'type':
473 478 Group outputs by type (show all stdout, then all stderr, etc.):
474 479
475 480 [stdout:1] foo
476 481 [stdout:2] foo
477 482 [stderr:1] bar
478 483 [stderr:2] bar
479 484 if 'engine':
480 485 Display outputs for each engine before moving on to the next:
481 486
482 487 [stdout:1] foo
483 488 [stderr:1] bar
484 489 [stdout:2] foo
485 490 [stderr:2] bar
486 491
487 492 if 'order':
488 493 Like 'type', but further collate individual displaypub
489 494 outputs. This is meant for cases of each command producing
490 495 several plots, and you would like to see all of the first
491 496 plots together, then all of the second plots, and so on.
492 497 """
493 498 if self._single_result:
494 499 self._display_single_result()
495 500 return
496 501
497 502 stdouts = self.stdout
498 503 stderrs = self.stderr
499 504 pyouts = self.pyout
500 505 output_lists = self.outputs
501 506 results = self.get()
502 507
503 508 targets = self.engine_id
504 509
505 510 if groupby == "engine":
506 511 for eid,stdout,stderr,outputs,r,pyout in zip(
507 512 targets, stdouts, stderrs, output_lists, results, pyouts
508 513 ):
509 514 self._display_stream(stdout, '[stdout:%i] ' % eid)
510 515 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
511 516
512 517 try:
513 518 get_ipython()
514 519 except NameError:
515 520 # displaypub is meaningless outside IPython
516 521 return
517 522
518 523 if outputs or pyout is not None:
519 524 _raw_text('[output:%i]' % eid)
520 525
521 526 for output in outputs:
522 527 self._republish_displaypub(output, eid)
523 528
524 529 if pyout is not None:
525 530 display(r)
526 531
527 532 elif groupby in ('type', 'order'):
528 533 # republish stdout:
529 534 for eid,stdout in zip(targets, stdouts):
530 535 self._display_stream(stdout, '[stdout:%i] ' % eid)
531 536
532 537 # republish stderr:
533 538 for eid,stderr in zip(targets, stderrs):
534 539 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
535 540
536 541 try:
537 542 get_ipython()
538 543 except NameError:
539 544 # displaypub is meaningless outside IPython
540 545 return
541 546
542 547 if groupby == 'order':
543 548 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
544 549 N = max(len(outputs) for outputs in output_lists)
545 550 for i in range(N):
546 551 for eid in targets:
547 552 outputs = output_dict[eid]
548 553 if len(outputs) >= N:
549 554 _raw_text('[output:%i]' % eid)
550 555 self._republish_displaypub(outputs[i], eid)
551 556 else:
552 557 # republish displaypub output
553 558 for eid,outputs in zip(targets, output_lists):
554 559 if outputs:
555 560 _raw_text('[output:%i]' % eid)
556 561 for output in outputs:
557 562 self._republish_displaypub(output, eid)
558 563
559 564 # finally, add pyout:
560 565 for eid,r,pyout in zip(targets, results, pyouts):
561 566 if pyout is not None:
562 567 display(r)
563 568
564 569 else:
565 570 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
566 571
567 572
568 573
569 574
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    This will properly reconstruct the gather.

    This class is iterable at any time, and will wait on results as they come.

    If ordered=False, then the first results to arrive will come first, otherwise
    results will be yielded in the order they were submitted.

    """

    def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        self._mapObject = mapObject
        # map results are always collections, never single results
        self._single_result = False
        self.ordered = ordered

    def _reconstruct_result(self, res):
        """Perform the gather on the actual results."""
        return self._mapObject.joinPartitions(res)

    # asynchronous iterator:
    def __iter__(self):
        # dispatch on ordering mode chosen at construction
        it = self._ordered_iter if self.ordered else self._unordered_iter
        for r in it():
            yield r

    # asynchronous ordered iterator:
    def _ordered_iter(self):
        """iterator for results *as they arrive*, preserving submission order."""
        try:
            # non-blocking probe: succeeds only if everything is already done
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                ar = AsyncResult(self._client, msg_id, self._fname)
                rlist = ar.get()
                try:
                    for r in rlist:
                        yield r
                except TypeError:
                    # flattened, not a list
                    # this could get broken by flattened data that returns iterables
                    # but most calls to map do not expose the `flatten` argument
                    yield rlist
        else:
            # already done
            for r in rlist:
                yield r

    # asynchronous unordered iterator:
    def _unordered_iter(self):
        """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
        try:
            # non-blocking probe: succeeds only if everything is already done
            rlist = self.get(0)
        except error.TimeoutError:
            pending = set(self.msg_ids)
            while pending:
                try:
                    self._client.wait(pending, 1e-3)
                except error.TimeoutError:
                    # ignore timeout error, because that only means
                    # *some* jobs are outstanding
                    pass
                # update ready set with those no longer outstanding:
                ready = pending.difference(self._client.outstanding)
                # update pending to exclude those that are finished
                pending = pending.difference(ready)
                while ready:
                    msg_id = ready.pop()
                    ar = AsyncResult(self._client, msg_id, self._fname)
                    rlist = ar.get()
                    try:
                        for r in rlist:
                            yield r
                    except TypeError:
                        # flattened, not a list
                        # this could get broken by flattened data that returns iterables
                        # but most calls to map do not expose the `flatten` argument
                        yield rlist
        else:
            # already done
            for r in rlist:
                yield r
655 660
656 661
class AsyncHubResult(AsyncResult):
    """Class to wrap pending results that must be requested from the Hub.

    Note that waiting/polling on these objects requires polling the Hub over the network,
    so use `AsyncHubResult.wait()` sparingly.
    """

    def _wait_for_outputs(self, timeout=-1):
        """no-op, because HubResults are never incomplete"""
        self._outputs_ready = True

    def wait(self, timeout=-1):
        """wait for result to complete."""
        start = time.time()
        if self._ready:
            return
        # split ids into those still in-flight locally vs. those held by the Hub
        local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
        local_ready = self._client.wait(local_ids, timeout)
        if local_ready:
            remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
            if not remote_ids:
                self._ready = True
            else:
                # poll the Hub (network round-trips) until done or timeout
                rdict = self._client.result_status(remote_ids, status_only=False)
                pending = rdict['pending']
                while pending and (timeout < 0 or time.time() < start+timeout):
                    rdict = self._client.result_status(remote_ids, status_only=False)
                    pending = rdict['pending']
                    if pending:
                        time.sleep(0.1)
                if not pending:
                    self._ready = True
        if self._ready:
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        # single failed result: surface the remote exception directly
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                    self._result = self._reconstruct_result(results)
            except Exception as e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                # refresh metadata now that results have been pulled from the Hub
                self._metadata = map(self._client.metadata.get, self.msg_ids)
707 712
708 713 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
General Comments 0
You need to be logged in to leave comments. Login now