add ownership to AsyncResult objects...
MinRK
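The commit gives AsyncResult an `owner` flag: an owning result takes responsibility for its msg_ids and, once it is done, pops their entries out of the client's `results` and `metadata` caches, so a long-lived client does not accumulate stale state. A minimal sketch of the effect, assuming a running IPython.parallel cluster (the initial `apply_async` call is only there to produce msg_ids):

    from IPython.parallel import Client
    from IPython.parallel.client.asyncresult import AsyncResult

    rc = Client()
    ar0 = rc[:].apply_async(lambda x: x + 1, 1)

    # Wrap the same message ids in an *owning* AsyncResult: when it finishes,
    # wait() pops the corresponding entries from rc.results / rc.metadata.
    ar = AsyncResult(rc, ar0.msg_ids, fname='apply', owner=True)
    print(ar.get())                          # one result per engine
    assert ar.msg_ids[0] not in rc.results   # cleaned up by the owner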
@@ -1,695 +1,703 @@
1 1 """AsyncResult objects for the client"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import sys
9 9 import time
10 10 from datetime import datetime
11 11
12 12 from zmq import MessageTracker
13 13
14 14 from IPython.core.display import clear_output, display, display_pretty
15 15 from IPython.external.decorator import decorator
16 16 from IPython.parallel import error
17 17 from IPython.utils.py3compat import string_types
18 18
19 #-----------------------------------------------------------------------------
20 # Functions
21 #-----------------------------------------------------------------------------
22 19
23 20 def _raw_text(s):
24 21 display_pretty(s, raw=True)
25 22
26 #-----------------------------------------------------------------------------
27 # Classes
28 #-----------------------------------------------------------------------------
29 23
30 24 # global empty tracker that's always done:
31 25 finished_tracker = MessageTracker()
32 26
33 27 @decorator
34 28 def check_ready(f, self, *args, **kwargs):
35 29 """Call spin() to sync state prior to calling the method."""
36 30 self.wait(0)
37 31 if not self._ready:
38 32 raise error.TimeoutError("result not ready")
39 33 return f(self, *args, **kwargs)
40 34
41 35 class AsyncResult(object):
42 36 """Class for representing results of non-blocking calls.
43 37
44 38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
45 39 """
46 40
47 41 msg_ids = None
48 42 _targets = None
49 43 _tracker = None
50 44 _single_result = False
45 owner = False
51 46
52 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 owner=False,
49 ):
53 50 if isinstance(msg_ids, string_types):
54 51 # always a list
55 52 msg_ids = [msg_ids]
56 53 self._single_result = True
57 54 else:
58 55 self._single_result = False
59 56 if tracker is None:
60 57 # default to always done
61 58 tracker = finished_tracker
62 59 self._client = client
63 60 self.msg_ids = msg_ids
64 61 self._fname=fname
65 62 self._targets = targets
66 63 self._tracker = tracker
64 self.owner = owner
67 65
68 66 self._ready = False
69 67 self._outputs_ready = False
70 68 self._success = None
71 69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
72 70
73 71 def __repr__(self):
74 72 if self._ready:
75 73 return "<%s: finished>"%(self.__class__.__name__)
76 74 else:
77 75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
78 76
79 77
80 78 def _reconstruct_result(self, res):
81 79 """Reconstruct our result from actual result list (always a list)
82 80
83 81 Override me in subclasses for turning a list of results
84 82 into the expected form.
85 83 """
86 84 if self._single_result:
87 85 return res[0]
88 86 else:
89 87 return res
90 88
91 89 def get(self, timeout=-1):
92 90 """Return the result when it arrives.
93 91
94 92 If `timeout` is not ``None`` and the result does not arrive within
95 93 `timeout` seconds then ``TimeoutError`` is raised. If the
96 94 remote call raised an exception then that exception will be reraised
97 95 by get() inside a `RemoteError`.
98 96 """
99 97 if not self.ready():
100 98 self.wait(timeout)
101 99
102 100 if self._ready:
103 101 if self._success:
104 102 return self._result
105 103 else:
106 104 raise self._exception
107 105 else:
108 106 raise error.TimeoutError("Result not ready.")
109 107
110 108 def _check_ready(self):
111 109 if not self.ready():
112 110 raise error.TimeoutError("Result not ready.")
113 111
114 112 def ready(self):
115 113 """Return whether the call has completed."""
116 114 if not self._ready:
117 115 self.wait(0)
118 116 elif not self._outputs_ready:
119 117 self._wait_for_outputs(0)
120 118
121 119 return self._ready
122 120
123 121 def wait(self, timeout=-1):
124 122 """Wait until the result is available or until `timeout` seconds pass.
125 123
126 124 This method always returns None.
127 125 """
128 126 if self._ready:
129 127 self._wait_for_outputs(timeout)
130 128 return
131 129 self._ready = self._client.wait(self.msg_ids, timeout)
132 130 if self._ready:
133 131 try:
134 132 results = list(map(self._client.results.get, self.msg_ids))
135 133 self._result = results
136 134 if self._single_result:
137 135 r = results[0]
138 136 if isinstance(r, Exception):
139 137 raise r
140 138 else:
141 139 results = error.collect_exceptions(results, self._fname)
142 140 self._result = self._reconstruct_result(results)
143 141 except Exception as e:
144 142 self._exception = e
145 143 self._success = False
146 144 else:
147 145 self._success = True
148 146 finally:
149 147 if timeout is None or timeout < 0:
150 148 # cutoff infinite wait at 10s
151 149 timeout = 10
152 150 self._wait_for_outputs(timeout)
151
152 if self.owner:
153
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
156
153 157
154 158
155 159 def successful(self):
156 160 """Return whether the call completed without raising an exception.
157 161
158 162 Will raise ``AssertionError`` if the result is not ready.
159 163 """
160 164 assert self.ready()
161 165 return self._success
162 166
163 167 #----------------------------------------------------------------
164 168 # Extra methods not in mp.pool.AsyncResult
165 169 #----------------------------------------------------------------
166 170
167 171 def get_dict(self, timeout=-1):
168 172 """Get the results as a dict, keyed by engine_id.
169 173
170 174 timeout behavior is described in `get()`.
171 175 """
172 176
173 177 results = self.get(timeout)
174 178 if self._single_result:
175 179 results = [results]
176 180 engine_ids = [ md['engine_id'] for md in self._metadata ]
177 181
178 182
179 183 rdict = {}
180 184 for engine_id, result in zip(engine_ids, results):
181 185 if engine_id in rdict:
182 186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
183 187 engine_ids.count(engine_id), engine_id)
184 188 )
185 189 else:
186 190 rdict[engine_id] = result
187 191
188 192 return rdict
189 193
190 194 @property
191 195 def result(self):
192 196 """result property wrapper for `get(timeout=-1)`."""
193 197 return self.get()
194 198
195 199 # abbreviated alias:
196 200 r = result
197 201
198 202 @property
199 203 def metadata(self):
200 204 """property for accessing execution metadata."""
201 205 if self._single_result:
202 206 return self._metadata[0]
203 207 else:
204 208 return self._metadata
205 209
206 210 @property
207 211 def result_dict(self):
208 212 """result property as a dict."""
209 213 return self.get_dict()
210 214
211 215 def __dict__(self):
212 216 return self.get_dict(0)
213 217
214 218 def abort(self):
215 219 """abort my tasks."""
216 220 assert not self.ready(), "Can't abort, I am already done!"
217 221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
218 222
219 223 @property
220 224 def sent(self):
221 225 """check whether my messages have been sent."""
222 226 return self._tracker.done
223 227
224 228 def wait_for_send(self, timeout=-1):
225 229 """wait for pyzmq send to complete.
226 230
227 231 This is necessary when sending arrays that you intend to edit in-place.
228 232 `timeout` is in seconds, and will raise TimeoutError if it is reached
229 233 before the send completes.
230 234 """
231 235 return self._tracker.wait(timeout)
232 236
233 237 #-------------------------------------
234 238 # dict-access
235 239 #-------------------------------------
236 240
237 241 def __getitem__(self, key):
238 242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
239 243 """
240 244 if isinstance(key, int):
241 245 self._check_ready()
242 246 return error.collect_exceptions([self._result[key]], self._fname)[0]
243 247 elif isinstance(key, slice):
244 248 self._check_ready()
245 249 return error.collect_exceptions(self._result[key], self._fname)
246 250 elif isinstance(key, string_types):
247 251 # metadata proxy *does not* require that results are done
248 252 self.wait(0)
249 253 values = [ md[key] for md in self._metadata ]
250 254 if self._single_result:
251 255 return values[0]
252 256 else:
253 257 return values
254 258 else:
255 259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
256 260
257 261 def __getattr__(self, key):
258 262 """getattr maps to getitem for convenient attr access to metadata."""
259 263 try:
260 264 return self.__getitem__(key)
261 265 except (error.TimeoutError, KeyError):
262 266 raise AttributeError("%r object has no attribute %r"%(
263 267 self.__class__.__name__, key))
264 268
265 269 # asynchronous iterator:
266 270 def __iter__(self):
267 271 if self._single_result:
268 272 raise TypeError("AsyncResults with a single result are not iterable.")
269 273 try:
270 274 rlist = self.get(0)
271 275 except error.TimeoutError:
272 276 # wait for each result individually
273 277 for msg_id in self.msg_ids:
274 278 ar = AsyncResult(self._client, msg_id, self._fname)
275 279 yield ar.get()
276 280 else:
277 281 # already done
278 282 for r in rlist:
279 283 yield r
280 284
281 285 def __len__(self):
282 286 return len(self.msg_ids)
283 287
284 288 #-------------------------------------
285 289 # Sugar methods and attributes
286 290 #-------------------------------------
287 291
288 292 def timedelta(self, start, end, start_key=min, end_key=max):
289 293 """compute the difference between two sets of timestamps
290 294
291 295 The default behavior is to use the earliest of the first
292 296 and the latest of the second list, but this can be changed
293 297 by passing a different function as `start_key` or `end_key`.
294 298
295 299 Parameters
296 300 ----------
297 301
298 302 start : one or more datetime objects (e.g. ar.submitted)
299 303 end : one or more datetime objects (e.g. ar.received)
300 304 start_key : callable
301 305 Function to call on `start` to extract the relevant
302 306 entry [default: min]
303 307 end_key : callable
304 308 Function to call on `end` to extract the relevant
305 309 entry [default: max]
306 310
307 311 Returns
308 312 -------
309 313
310 314 dt : float
311 315 The time elapsed (in seconds) between the two selected timestamps.
312 316 """
313 317 if not isinstance(start, datetime):
314 318 # handle single_result AsyncResults, where ar.stamp is single object,
315 319 # not a list
316 320 start = start_key(start)
317 321 if not isinstance(end, datetime):
318 322 # handle single_result AsyncResults, where ar.stamp is single object,
319 323 # not a list
320 324 end = end_key(end)
321 325 return (end - start).total_seconds()
322 326
323 327 @property
324 328 def progress(self):
325 329 """the number of tasks which have been completed at this point.
326 330
327 331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
328 332 """
329 333 self.wait(0)
330 334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
331 335
332 336 @property
333 337 def elapsed(self):
334 338 """elapsed time since initial submission"""
335 339 if self.ready():
336 340 return self.wall_time
337 341
338 342 now = submitted = datetime.now()
339 343 for msg_id in self.msg_ids:
340 344 if msg_id in self._client.metadata:
341 345 stamp = self._client.metadata[msg_id]['submitted']
342 346 if stamp and stamp < submitted:
343 347 submitted = stamp
344 348 return (now-submitted).total_seconds()
345 349
346 350 @property
347 351 @check_ready
348 352 def serial_time(self):
349 353 """serial computation time of a parallel calculation
350 354
351 355 Computed as the sum of (completed-started) of each task
352 356 """
353 357 t = 0
354 358 for md in self._metadata:
355 359 t += (md['completed'] - md['started']).total_seconds()
356 360 return t
357 361
358 362 @property
359 363 @check_ready
360 364 def wall_time(self):
361 365 """actual computation time of a parallel calculation
362 366
363 367 Computed as the time between the latest `received` stamp
364 368 and the earliest `submitted`.
365 369
366 370 Only reliable if Client was spinning/waiting when the task finished, because
367 371 the `received` timestamp is created when a result is pulled off of the zmq queue,
368 372 which happens as a result of `client.spin()`.
369 373
370 374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
371 375
372 376 """
373 377 return self.timedelta(self.submitted, self.received)
374 378
375 379 def wait_interactive(self, interval=1., timeout=-1):
376 380 """interactive wait, printing progress at regular intervals"""
377 381 if timeout is None:
378 382 timeout = -1
379 383 N = len(self)
380 384 tic = time.time()
381 385 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
382 386 self.wait(interval)
383 387 clear_output(wait=True)
384 388 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
385 389 sys.stdout.flush()
386 390 print()
387 391 print("done")
388 392
389 393 def _republish_displaypub(self, content, eid):
390 394 """republish individual displaypub content dicts"""
391 395 try:
392 396 ip = get_ipython()
393 397 except NameError:
394 398 # displaypub is meaningless outside IPython
395 399 return
396 400 md = content['metadata'] or {}
397 401 md['engine'] = eid
398 402 ip.display_pub.publish(content['source'], content['data'], md)
399 403
400 404 def _display_stream(self, text, prefix='', file=None):
401 405 if not text:
402 406 # nothing to display
403 407 return
404 408 if file is None:
405 409 file = sys.stdout
406 410 end = '' if text.endswith('\n') else '\n'
407 411
408 412 multiline = text.count('\n') > int(text.endswith('\n'))
409 413 if prefix and multiline and not text.startswith('\n'):
410 414 prefix = prefix + '\n'
411 415 print("%s%s" % (prefix, text), file=file, end=end)
412 416
413 417
414 418 def _display_single_result(self):
415 419 self._display_stream(self.stdout)
416 420 self._display_stream(self.stderr, file=sys.stderr)
417 421
418 422 try:
419 423 get_ipython()
420 424 except NameError:
421 425 # displaypub is meaningless outside IPython
422 426 return
423 427
424 428 for output in self.outputs:
425 429 self._republish_displaypub(output, self.engine_id)
426 430
427 431 if self.execute_result is not None:
428 432 display(self.get())
429 433
430 434 def _wait_for_outputs(self, timeout=-1):
431 435 """wait for the 'status=idle' message that indicates we have all outputs
432 436 """
433 437 if self._outputs_ready or not self._success:
434 438 # don't wait on errors
435 439 return
436 440
437 441 # cast None to -1 for infinite timeout
438 442 if timeout is None:
439 443 timeout = -1
440 444
441 445 tic = time.time()
442 446 while True:
443 447 self._client._flush_iopub(self._client._iopub_socket)
444 448 self._outputs_ready = all(md['outputs_ready']
445 449 for md in self._metadata)
446 450 if self._outputs_ready or \
447 451 (timeout >= 0 and time.time() > tic + timeout):
448 452 break
449 453 time.sleep(0.01)
450 454
451 455 @check_ready
452 456 def display_outputs(self, groupby="type"):
453 457 """republish the outputs of the computation
454 458
455 459 Parameters
456 460 ----------
457 461
458 462 groupby : str [default: type]
459 463 if 'type':
460 464 Group outputs by type (show all stdout, then all stderr, etc.):
461 465
462 466 [stdout:1] foo
463 467 [stdout:2] foo
464 468 [stderr:1] bar
465 469 [stderr:2] bar
466 470 if 'engine':
467 471 Display outputs for each engine before moving on to the next:
468 472
469 473 [stdout:1] foo
470 474 [stderr:1] bar
471 475 [stdout:2] foo
472 476 [stderr:2] bar
473 477
474 478 if 'order':
475 479 Like 'type', but further collate individual displaypub
476 480 outputs. This is meant for cases of each command producing
477 481 several plots, and you would like to see all of the first
478 482 plots together, then all of the second plots, and so on.
479 483 """
480 484 if self._single_result:
481 485 self._display_single_result()
482 486 return
483 487
484 488 stdouts = self.stdout
485 489 stderrs = self.stderr
486 490 execute_results = self.execute_result
487 491 output_lists = self.outputs
488 492 results = self.get()
489 493
490 494 targets = self.engine_id
491 495
492 496 if groupby == "engine":
493 497 for eid,stdout,stderr,outputs,r,execute_result in zip(
494 498 targets, stdouts, stderrs, output_lists, results, execute_results
495 499 ):
496 500 self._display_stream(stdout, '[stdout:%i] ' % eid)
497 501 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
498 502
499 503 try:
500 504 get_ipython()
501 505 except NameError:
502 506 # displaypub is meaningless outside IPython
503 507 return
504 508
505 509 if outputs or execute_result is not None:
506 510 _raw_text('[output:%i]' % eid)
507 511
508 512 for output in outputs:
509 513 self._republish_displaypub(output, eid)
510 514
511 515 if execute_result is not None:
512 516 display(r)
513 517
514 518 elif groupby in ('type', 'order'):
515 519 # republish stdout:
516 520 for eid,stdout in zip(targets, stdouts):
517 521 self._display_stream(stdout, '[stdout:%i] ' % eid)
518 522
519 523 # republish stderr:
520 524 for eid,stderr in zip(targets, stderrs):
521 525 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
522 526
523 527 try:
524 528 get_ipython()
525 529 except NameError:
526 530 # displaypub is meaningless outside IPython
527 531 return
528 532
529 533 if groupby == 'order':
530 534 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
531 535 N = max(len(outputs) for outputs in output_lists)
532 536 for i in range(N):
533 537 for eid in targets:
534 538 outputs = output_dict[eid]
535 539 if len(outputs) >= N:
536 540 _raw_text('[output:%i]' % eid)
537 541 self._republish_displaypub(outputs[i], eid)
538 542 else:
539 543 # republish displaypub output
540 544 for eid,outputs in zip(targets, output_lists):
541 545 if outputs:
542 546 _raw_text('[output:%i]' % eid)
543 547 for output in outputs:
544 548 self._republish_displaypub(output, eid)
545 549
546 550 # finally, add execute_result:
547 551 for eid,r,execute_result in zip(targets, results, execute_results):
548 552 if execute_result is not None:
549 553 display(r)
550 554
551 555 else:
552 556 raise ValueError("groupby must be one of 'type', 'engine', 'order', not %r" % groupby)
553 557
554 558
555 559
556 560
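The groupby modes documented in display_outputs above can be exercised from a blocking or non-blocking execute; a short sketch, assuming a running cluster:

    from IPython.parallel import Client

    rc = Client()
    ar = rc[:].execute("import os; print(os.getpid())", block=False)
    ar.wait()
    ar.display_outputs(groupby='engine')   # stdout/stderr per engine, in turn
    ar.display_outputs()                   # default: grouped by output type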
557 561 class AsyncMapResult(AsyncResult):
558 562 """Class for representing results of non-blocking gathers.
559 563
560 564 This will properly reconstruct the gather.
561 565
562 566 This class is iterable at any time, and will wait on results as they come.
563 567
564 568 If ordered=False, then the first results to arrive will come first, otherwise
565 569 results will be yielded in the order they were submitted.
566 570
567 571 """
568 572
569 573 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
570 574 AsyncResult.__init__(self, client, msg_ids, fname=fname)
571 575 self._mapObject = mapObject
572 576 self._single_result = False
573 577 self.ordered = ordered
574 578
575 579 def _reconstruct_result(self, res):
576 580 """Perform the gather on the actual results."""
577 581 return self._mapObject.joinPartitions(res)
578 582
579 583 # asynchronous iterator:
580 584 def __iter__(self):
581 585 it = self._ordered_iter if self.ordered else self._unordered_iter
582 586 for r in it():
583 587 yield r
584 588
585 589 # asynchronous ordered iterator:
586 590 def _ordered_iter(self):
587 591 """iterator for results *as they arrive*, preserving submission order."""
588 592 try:
589 593 rlist = self.get(0)
590 594 except error.TimeoutError:
591 595 # wait for each result individually
592 596 for msg_id in self.msg_ids:
593 597 ar = AsyncResult(self._client, msg_id, self._fname)
594 598 rlist = ar.get()
595 599 try:
596 600 for r in rlist:
597 601 yield r
598 602 except TypeError:
599 603 # flattened, not a list
600 604 # this could get broken by flattened data that returns iterables
601 605 # but most calls to map do not expose the `flatten` argument
602 606 yield rlist
603 607 else:
604 608 # already done
605 609 for r in rlist:
606 610 yield r
607 611
608 612 # asynchronous unordered iterator:
609 613 def _unordered_iter(self):
610 614 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
611 615 try:
612 616 rlist = self.get(0)
613 617 except error.TimeoutError:
614 618 pending = set(self.msg_ids)
615 619 while pending:
616 620 try:
617 621 self._client.wait(pending, 1e-3)
618 622 except error.TimeoutError:
619 623 # ignore timeout error, because that only means
620 624 # *some* jobs are outstanding
621 625 pass
622 626 # update ready set with those no longer outstanding:
623 627 ready = pending.difference(self._client.outstanding)
624 628 # update pending to exclude those that are finished
625 629 pending = pending.difference(ready)
626 630 while ready:
627 631 msg_id = ready.pop()
628 632 ar = AsyncResult(self._client, msg_id, self._fname)
629 633 rlist = ar.get()
630 634 try:
631 635 for r in rlist:
632 636 yield r
633 637 except TypeError:
634 638 # flattened, not a list
635 639 # this could get broken by flattened data that returns iterables
636 640 # but most calls to map do not expose the `flatten` argument
637 641 yield rlist
638 642 else:
639 643 # already done
640 644 for r in rlist:
641 645 yield r
642 646
643 647
644 648 class AsyncHubResult(AsyncResult):
645 649 """Class to wrap pending results that must be requested from the Hub.
646 650
647 651 Note that waiting/polling on these objects requires polling the Hub over the network,
648 652 so use `AsyncHubResult.wait()` sparingly.
649 653 """
650 654
651 655 def _wait_for_outputs(self, timeout=-1):
652 656 """no-op, because HubResults are never incomplete"""
653 657 self._outputs_ready = True
654 658
655 659 def wait(self, timeout=-1):
656 660 """wait for result to complete."""
657 661 start = time.time()
658 662 if self._ready:
659 663 return
660 664 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
661 665 local_ready = self._client.wait(local_ids, timeout)
662 666 if local_ready:
663 667 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
664 668 if not remote_ids:
665 669 self._ready = True
666 670 else:
667 671 rdict = self._client.result_status(remote_ids, status_only=False)
668 672 pending = rdict['pending']
669 673 while pending and (timeout < 0 or time.time() < start+timeout):
670 674 rdict = self._client.result_status(remote_ids, status_only=False)
671 675 pending = rdict['pending']
672 676 if pending:
673 677 time.sleep(0.1)
674 678 if not pending:
675 679 self._ready = True
676 680 if self._ready:
677 681 try:
678 682 results = list(map(self._client.results.get, self.msg_ids))
679 683 self._result = results
680 684 if self._single_result:
681 685 r = results[0]
682 686 if isinstance(r, Exception):
683 687 raise r
684 688 else:
685 689 results = error.collect_exceptions(results, self._fname)
686 690 self._result = self._reconstruct_result(results)
687 691 except Exception as e:
688 692 self._exception = e
689 693 self._success = False
690 694 else:
691 695 self._success = True
692 696 finally:
693 697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
698 if self.owner:
699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
700 [self._client.results.pop(mid) for mid in self.msg_ids]
701
694 702
695 703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
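The two AsyncMapResult iterators above yield results as they arrive, either preserving submission order or on a first-come-first-served basis. A sketch of the difference, assuming a running cluster (`slow_id` is illustrative):

    from IPython.parallel import Client

    rc = Client()
    view = rc.load_balanced_view()

    def slow_id(x):
        import random, time
        time.sleep(random.random())
        return x

    amr = view.map_async(slow_id, range(8), ordered=False)
    for r in amr:        # FCFS: fastest results come out first
        print(r)

    amr = view.map_async(slow_id, range(8))   # ordered=True is the default
    print(list(amr))     # always [0, 1, ..., 7]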
@@ -1,1863 +1,1868 @@
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 21
22 22 from IPython.config.configurable import MultipleInstanceError
23 23 from IPython.core.application import BaseIPythonApplication
24 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 25
26 26 from IPython.utils.capture import RichOutput
27 27 from IPython.utils.coloransi import TermColors
28 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 30 from IPython.utils.path import get_ipython_dir
31 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 33 Dict, List, Bool, Set, Any)
34 34 from IPython.external.decorator import decorator
35 35 from IPython.external.ssh import tunnel
36 36
37 37 from IPython.parallel import Reference
38 38 from IPython.parallel import error
39 39 from IPython.parallel import util
40 40
41 41 from IPython.kernel.zmq.session import Session, Message
42 42 from IPython.kernel.zmq import serialize
43 43
44 44 from .asyncresult import AsyncResult, AsyncHubResult
45 45 from .view import DirectView, LoadBalancedView
46 46
47 47 #--------------------------------------------------------------------------
48 48 # Decorators for Client methods
49 49 #--------------------------------------------------------------------------
50 50
51 51 @decorator
52 52 def spin_first(f, self, *args, **kwargs):
53 53 """Call spin() to sync state prior to calling the method."""
54 54 self.spin()
55 55 return f(self, *args, **kwargs)
56 56
57 57
58 58 #--------------------------------------------------------------------------
59 59 # Classes
60 60 #--------------------------------------------------------------------------
61 61
62 62
63 63 class ExecuteReply(RichOutput):
64 64 """wrapper for finished Execute results"""
65 65 def __init__(self, msg_id, content, metadata):
66 66 self.msg_id = msg_id
67 67 self._content = content
68 68 self.execution_count = content['execution_count']
69 69 self.metadata = metadata
70 70
71 71 # RichOutput overrides
72 72
73 73 @property
74 74 def source(self):
75 75 execute_result = self.metadata['execute_result']
76 76 if execute_result:
77 77 return execute_result.get('source', '')
78 78
79 79 @property
80 80 def data(self):
81 81 execute_result = self.metadata['execute_result']
82 82 if execute_result:
83 83 return execute_result.get('data', {})
84 84
85 85 @property
86 86 def _metadata(self):
87 87 execute_result = self.metadata['execute_result']
88 88 if execute_result:
89 89 return execute_result.get('metadata', {})
90 90
91 91 def display(self):
92 92 from IPython.display import publish_display_data
93 93 publish_display_data(self.data, self.metadata)
94 94
95 95 def _repr_mime_(self, mime):
96 96 if mime not in self.data:
97 97 return
98 98 data = self.data[mime]
99 99 if mime in self._metadata:
100 100 return data, self._metadata[mime]
101 101 else:
102 102 return data
103 103
104 104 def __getitem__(self, key):
105 105 return self.metadata[key]
106 106
107 107 def __getattr__(self, key):
108 108 if key not in self.metadata:
109 109 raise AttributeError(key)
110 110 return self.metadata[key]
111 111
112 112 def __repr__(self):
113 113 execute_result = self.metadata['execute_result'] or {'data':{}}
114 114 text_out = execute_result['data'].get('text/plain', '')
115 115 if len(text_out) > 32:
116 116 text_out = text_out[:29] + '...'
117 117
118 118 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
119 119
120 120 def _repr_pretty_(self, p, cycle):
121 121 execute_result = self.metadata['execute_result'] or {'data':{}}
122 122 text_out = execute_result['data'].get('text/plain', '')
123 123
124 124 if not text_out:
125 125 return
126 126
127 127 try:
128 128 ip = get_ipython()
129 129 except NameError:
130 130 colors = "NoColor"
131 131 else:
132 132 colors = ip.colors
133 133
134 134 if colors == "NoColor":
135 135 out = normal = ""
136 136 else:
137 137 out = TermColors.Red
138 138 normal = TermColors.Normal
139 139
140 140 if '\n' in text_out and not text_out.startswith('\n'):
141 141 # add newline for multiline reprs
142 142 text_out = '\n' + text_out
143 143
144 144 p.text(
145 145 out + u'Out[%i:%i]: ' % (
146 146 self.metadata['engine_id'], self.execution_count
147 147 ) + normal + text_out
148 148 )
149 149
150 150
151 151 class Metadata(dict):
152 152 """Subclass of dict for initializing metadata values.
153 153
154 154 Attribute access works on keys.
155 155
156 156 These objects have a strict set of keys - errors will raise if you try
157 157 to add new keys.
158 158 """
159 159 def __init__(self, *args, **kwargs):
160 160 dict.__init__(self)
161 161 md = {'msg_id' : None,
162 162 'submitted' : None,
163 163 'started' : None,
164 164 'completed' : None,
165 165 'received' : None,
166 166 'engine_uuid' : None,
167 167 'engine_id' : None,
168 168 'follow' : None,
169 169 'after' : None,
170 170 'status' : None,
171 171
172 172 'execute_input' : None,
173 173 'execute_result' : None,
174 174 'error' : None,
175 175 'stdout' : '',
176 176 'stderr' : '',
177 177 'outputs' : [],
178 178 'data': {},
179 179 'outputs_ready' : False,
180 180 }
181 181 self.update(md)
182 182 self.update(dict(*args, **kwargs))
183 183
184 184 def __getattr__(self, key):
185 185 """getattr aliased to getitem"""
186 186 if key in self:
187 187 return self[key]
188 188 else:
189 189 raise AttributeError(key)
190 190
191 191 def __setattr__(self, key, value):
192 192 """setattr aliased to setitem, with strict"""
193 193 if key in self:
194 194 self[key] = value
195 195 else:
196 196 raise AttributeError(key)
197 197
198 198 def __setitem__(self, key, value):
199 199 """strict static key enforcement"""
200 200 if key in self:
201 201 dict.__setitem__(self, key, value)
202 202 else:
203 203 raise KeyError(key)
204 204
205 205
206 206 class Client(HasTraits):
207 207 """A semi-synchronous client to the IPython ZMQ cluster
208 208
209 209 Parameters
210 210 ----------
211 211
212 212 url_file : str/unicode; path to ipcontroller-client.json
213 213 This JSON file should contain all the information needed to connect to a cluster,
214 214 and is likely the only argument needed.
215 215 Connection information for the Hub's registration. If a json connector
216 216 file is given, then likely no further configuration is necessary.
217 217 [Default: use profile]
218 218 profile : bytes
219 219 The name of the Cluster profile to be used to find connector information.
220 220 If run from an IPython application, the default profile will be the same
221 221 as the running application, otherwise it will be 'default'.
222 222 cluster_id : str
223 223 String id to be added to runtime files, to prevent name collisions when using
224 224 multiple clusters with a single profile simultaneously.
225 225 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
226 226 Since this is text inserted into filenames, typical recommendations apply:
227 227 Simple character strings are ideal, and spaces are not recommended (but
228 228 should generally work)
229 229 context : zmq.Context
230 230 Pass an existing zmq.Context instance, otherwise the client will create its own.
231 231 debug : bool
232 232 flag for lots of message printing for debug purposes
233 233 timeout : int/float
234 234 time (in seconds) to wait for connection replies from the Hub
235 235 [Default: 10]
236 236
237 237 #-------------- session related args ----------------
238 238
239 239 config : Config object
240 240 If specified, this will be relayed to the Session for configuration
241 241 username : str
242 242 set username for the session object
243 243
244 244 #-------------- ssh related args ----------------
245 245 # These are args for configuring the ssh tunnel to be used
246 246 # credentials are used to forward connections over ssh to the Controller
247 247 # Note that the ip given in `addr` needs to be relative to sshserver
248 248 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
249 249 # and set sshserver as the same machine the Controller is on. However,
250 250 # the only requirement is that sshserver is able to see the Controller
251 251 # (i.e. is within the same trusted network).
252 252
253 253 sshserver : str
254 254 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
255 255 If keyfile or password is specified, and this is not, it will default to
256 256 the ip given in addr.
257 257 sshkey : str; path to ssh private key file
258 258 This specifies a key to be used in ssh login, default None.
259 259 Regular default ssh keys will be used without specifying this argument.
260 260 password : str
261 261 Your ssh password to sshserver. Note that if this is left None,
262 262 you will be prompted for it if passwordless key based login is unavailable.
263 263 paramiko : bool
264 264 flag for whether to use paramiko instead of shell ssh for tunneling.
265 265 [default: True on win32, False else]
266 266
267 267
268 268 Attributes
269 269 ----------
270 270
271 271 ids : list of int engine IDs
272 272 requesting the ids attribute always synchronizes
273 273 the registration state. To request ids without synchronization,
274 274 use the semi-private _ids attribute.
275 275
276 276 history : list of msg_ids
277 277 a list of msg_ids, keeping track of all the execution
278 278 messages you have submitted in order.
279 279
280 280 outstanding : set of msg_ids
281 281 a set of msg_ids that have been submitted, but whose
282 282 results have not yet been received.
283 283
284 284 results : dict
285 285 a dict of all our results, keyed by msg_id
286 286
287 287 block : bool
288 288 determines default behavior when block not specified
289 289 in execution methods
290 290
291 291 Methods
292 292 -------
293 293
294 294 spin
295 295 flushes incoming results and registration state changes
296 296 control methods spin, and requesting `ids` also keeps state up to date
297 297
298 298 wait
299 299 wait on one or more msg_ids
300 300
301 301 execution methods
302 302 apply
303 303 legacy: execute, run
304 304
305 305 data movement
306 306 push, pull, scatter, gather
307 307
308 308 query methods
309 309 queue_status, get_result, purge, result_status
310 310
311 311 control methods
312 312 abort, shutdown
313 313
314 314 """
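A brief connection sketch matching the parameters above (the connection-file path and ssh server are illustrative, not from this diff):

    from IPython.parallel import Client

    # default: find ipcontroller-client.json via the active (or 'default') profile
    rc = Client()

    # or connect explicitly, tunneling to a remote controller over ssh:
    rc = Client('/path/to/ipcontroller-client.json', sshserver='user@server.tld')
    print(rc.ids)    # currently-registered engine ids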
315 315
316 316
317 317 block = Bool(False)
318 318 outstanding = Set()
319 319 results = Instance('collections.defaultdict', (dict,))
320 320 metadata = Instance('collections.defaultdict', (Metadata,))
321 321 history = List()
322 322 debug = Bool(False)
323 323 _spin_thread = Any()
324 324 _stop_spinning = Any()
325 325
326 326 profile=Unicode()
327 327 def _profile_default(self):
328 328 if BaseIPythonApplication.initialized():
329 329 # an IPython app *might* be running, try to get its profile
330 330 try:
331 331 return BaseIPythonApplication.instance().profile
332 332 except (AttributeError, MultipleInstanceError):
333 333 # could be a *different* subclass of config.Application,
334 334 # which would raise one of these two errors.
335 335 return u'default'
336 336 else:
337 337 return u'default'
338 338
339 339
340 340 _outstanding_dict = Instance('collections.defaultdict', (set,))
341 341 _ids = List()
342 342 _connected=Bool(False)
343 343 _ssh=Bool(False)
344 344 _context = Instance('zmq.Context')
345 345 _config = Dict()
346 346 _engines=Instance(util.ReverseDict, (), {})
347 347 # _hub_socket=Instance('zmq.Socket')
348 348 _query_socket=Instance('zmq.Socket')
349 349 _control_socket=Instance('zmq.Socket')
350 350 _iopub_socket=Instance('zmq.Socket')
351 351 _notification_socket=Instance('zmq.Socket')
352 352 _mux_socket=Instance('zmq.Socket')
353 353 _task_socket=Instance('zmq.Socket')
354 354 _task_scheme=Unicode()
355 355 _closed = False
356 356 _ignored_control_replies=Integer(0)
357 357 _ignored_hub_replies=Integer(0)
358 358
359 359 def __new__(self, *args, **kw):
360 360 # don't raise on positional args
361 361 return HasTraits.__new__(self, **kw)
362 362
363 363 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
364 364 context=None, debug=False,
365 365 sshserver=None, sshkey=None, password=None, paramiko=None,
366 366 timeout=10, cluster_id=None, **extra_args
367 367 ):
368 368 if profile:
369 369 super(Client, self).__init__(debug=debug, profile=profile)
370 370 else:
371 371 super(Client, self).__init__(debug=debug)
372 372 if context is None:
373 373 context = zmq.Context.instance()
374 374 self._context = context
375 375 self._stop_spinning = Event()
376 376
377 377 if 'url_or_file' in extra_args:
378 378 url_file = extra_args['url_or_file']
379 379 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
380 380
381 381 if url_file and util.is_url(url_file):
382 382 raise ValueError("single urls cannot be specified, url-files must be used.")
383 383
384 384 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
385 385
386 386 if self._cd is not None:
387 387 if url_file is None:
388 388 if not cluster_id:
389 389 client_json = 'ipcontroller-client.json'
390 390 else:
391 391 client_json = 'ipcontroller-%s-client.json' % cluster_id
392 392 url_file = pjoin(self._cd.security_dir, client_json)
393 393 if url_file is None:
394 394 raise ValueError(
395 395 "I can't find enough information to connect to a hub!"
396 396 " Please specify at least one of url_file or profile."
397 397 )
398 398
399 399 with open(url_file) as f:
400 400 cfg = json.load(f)
401 401
402 402 self._task_scheme = cfg['task_scheme']
403 403
404 404 # sync defaults from args, json:
405 405 if sshserver:
406 406 cfg['ssh'] = sshserver
407 407
408 408 location = cfg.setdefault('location', None)
409 409
410 410 proto,addr = cfg['interface'].split('://')
411 411 addr = util.disambiguate_ip_address(addr, location)
412 412 cfg['interface'] = "%s://%s" % (proto, addr)
413 413
414 414 # turn interface,port into full urls:
415 415 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
416 416 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
417 417
418 418 url = cfg['registration']
419 419
420 420 if location is not None and addr == localhost():
421 421 # location specified, and connection is expected to be local
422 422 if not is_local_ip(location) and not sshserver:
423 423 # load ssh from JSON *only* if the controller is not on
424 424 # this machine
425 425 sshserver=cfg['ssh']
426 426 if not is_local_ip(location) and not sshserver:
427 427 # warn if no ssh specified, but SSH is probably needed
428 428 # This is only a warning, because the most likely cause
429 429 # is a local Controller on a laptop whose IP is dynamic
430 430 warnings.warn("""
431 431 Controller appears to be listening on localhost, but not on this machine.
432 432 If this is true, you should specify Client(...,sshserver='you@%s')
433 433 or instruct your controller to listen on an external IP."""%location,
434 434 RuntimeWarning)
435 435 elif not sshserver:
436 436 # otherwise sync with cfg
437 437 sshserver = cfg['ssh']
438 438
439 439 self._config = cfg
440 440
441 441 self._ssh = bool(sshserver or sshkey or password)
442 442 if self._ssh and sshserver is None:
443 443 # default to ssh via localhost
444 444 sshserver = addr
445 445 if self._ssh and password is None:
446 446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
447 447 password=False
448 448 else:
449 449 password = getpass("SSH Password for %s: "%sshserver)
450 450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
451 451
452 452 # configure and construct the session
453 453 try:
454 454 extra_args['packer'] = cfg['pack']
455 455 extra_args['unpacker'] = cfg['unpack']
456 456 extra_args['key'] = cast_bytes(cfg['key'])
457 457 extra_args['signature_scheme'] = cfg['signature_scheme']
458 458 except KeyError as exc:
459 459 msg = '\n'.join([
460 460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
461 461 "If you are reusing connection files, remove them and start ipcontroller again."
462 462 ])
463 463 raise ValueError(msg.format(exc.message))
464 464
465 465 self.session = Session(**extra_args)
466 466
467 467 self._query_socket = self._context.socket(zmq.DEALER)
468 468
469 469 if self._ssh:
470 470 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
471 471 else:
472 472 self._query_socket.connect(cfg['registration'])
473 473
474 474 self.session.debug = self.debug
475 475
476 476 self._notification_handlers = {'registration_notification' : self._register_engine,
477 477 'unregistration_notification' : self._unregister_engine,
478 478 'shutdown_notification' : lambda msg: self.close(),
479 479 }
480 480 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
481 481 'apply_reply' : self._handle_apply_reply}
482 482
483 483 try:
484 484 self._connect(sshserver, ssh_kwargs, timeout)
485 485 except:
486 486 self.close(linger=0)
487 487 raise
488 488
489 489 # last step: setup magics, if we are in IPython:
490 490
491 491 try:
492 492 ip = get_ipython()
493 493 except NameError:
494 494 return
495 495 else:
496 496 if 'px' not in ip.magics_manager.magics:
497 497 # in IPython but we are the first Client.
498 498 # activate a default view for parallel magics.
499 499 self.activate()
500 500
501 501 def __del__(self):
502 502 """cleanup sockets, but _not_ context."""
503 503 self.close()
504 504
505 505 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
506 506 if ipython_dir is None:
507 507 ipython_dir = get_ipython_dir()
508 508 if profile_dir is not None:
509 509 try:
510 510 self._cd = ProfileDir.find_profile_dir(profile_dir)
511 511 return
512 512 except ProfileDirError:
513 513 pass
514 514 elif profile is not None:
515 515 try:
516 516 self._cd = ProfileDir.find_profile_dir_by_name(
517 517 ipython_dir, profile)
518 518 return
519 519 except ProfileDirError:
520 520 pass
521 521 self._cd = None
522 522
523 523 def _update_engines(self, engines):
524 524 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
525 525 for k,v in iteritems(engines):
526 526 eid = int(k)
527 527 if eid not in self._engines:
528 528 self._ids.append(eid)
529 529 self._engines[eid] = v
530 530 self._ids = sorted(self._ids)
531 531 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
532 532 self._task_scheme == 'pure' and self._task_socket:
533 533 self._stop_scheduling_tasks()
534 534
535 535 def _stop_scheduling_tasks(self):
536 536 """Stop scheduling tasks because an engine has been unregistered
537 537 from a pure ZMQ scheduler.
538 538 """
539 539 self._task_socket.close()
540 540 self._task_socket = None
541 541 msg = "An engine has been unregistered, and we are using pure " +\
542 542 "ZMQ task scheduling. Task farming will be disabled."
543 543 if self.outstanding:
544 544 msg += " If you were running tasks when this happened, " +\
545 545 "some `outstanding` msg_ids may never resolve."
546 546 warnings.warn(msg, RuntimeWarning)
547 547
548 548 def _build_targets(self, targets):
549 549 """Turn valid target IDs or 'all' into two lists:
550 550 (int_ids, uuids).
551 551 """
552 552 if not self._ids:
553 553 # flush notification socket if no engines yet, just in case
554 554 if not self.ids:
555 555 raise error.NoEnginesRegistered("Can't build targets without any engines")
556 556
557 557 if targets is None:
558 558 targets = self._ids
559 559 elif isinstance(targets, string_types):
560 560 if targets.lower() == 'all':
561 561 targets = self._ids
562 562 else:
563 563 raise TypeError("%r not valid str target, must be 'all'"%(targets))
564 564 elif isinstance(targets, int):
565 565 if targets < 0:
566 566 targets = self.ids[targets]
567 567 if targets not in self._ids:
568 568 raise IndexError("No such engine: %i"%targets)
569 569 targets = [targets]
570 570
571 571 if isinstance(targets, slice):
572 572 indices = list(range(len(self._ids))[targets])
573 573 ids = self.ids
574 574 targets = [ ids[i] for i in indices ]
575 575
576 576 if not isinstance(targets, (tuple, list, xrange)):
577 577 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
578 578
579 579 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
580 580
581 581 def _connect(self, sshserver, ssh_kwargs, timeout):
582 582 """setup all our socket connections to the cluster. This is called from
583 583 __init__."""
584 584
585 585 # Maybe allow reconnecting?
586 586 if self._connected:
587 587 return
588 588 self._connected=True
589 589
590 590 def connect_socket(s, url):
591 591 if self._ssh:
592 592 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
593 593 else:
594 594 return s.connect(url)
595 595
596 596 self.session.send(self._query_socket, 'connection_request')
597 597 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
598 598 poller = zmq.Poller()
599 599 poller.register(self._query_socket, zmq.POLLIN)
600 600 # poll expects milliseconds, timeout is seconds
601 601 evts = poller.poll(timeout*1000)
602 602 if not evts:
603 603 raise error.TimeoutError("Hub connection request timed out")
604 604 idents,msg = self.session.recv(self._query_socket,mode=0)
605 605 if self.debug:
606 606 pprint(msg)
607 607 content = msg['content']
608 608 # self._config['registration'] = dict(content)
609 609 cfg = self._config
610 610 if content['status'] == 'ok':
611 611 self._mux_socket = self._context.socket(zmq.DEALER)
612 612 connect_socket(self._mux_socket, cfg['mux'])
613 613
614 614 self._task_socket = self._context.socket(zmq.DEALER)
615 615 connect_socket(self._task_socket, cfg['task'])
616 616
617 617 self._notification_socket = self._context.socket(zmq.SUB)
618 618 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
619 619 connect_socket(self._notification_socket, cfg['notification'])
620 620
621 621 self._control_socket = self._context.socket(zmq.DEALER)
622 622 connect_socket(self._control_socket, cfg['control'])
623 623
624 624 self._iopub_socket = self._context.socket(zmq.SUB)
625 625 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
626 626 connect_socket(self._iopub_socket, cfg['iopub'])
627 627
628 628 self._update_engines(dict(content['engines']))
629 629 else:
630 630 self._connected = False
631 631 raise Exception("Failed to connect!")
632 632
633 633 #--------------------------------------------------------------------------
634 634 # handlers and callbacks for incoming messages
635 635 #--------------------------------------------------------------------------
636 636
637 637 def _unwrap_exception(self, content):
638 638 """unwrap exception, and remap engine_id to int."""
639 639 e = error.unwrap_exception(content)
640 640 # print e.traceback
641 641 if e.engine_info:
642 642 e_uuid = e.engine_info['engine_uuid']
643 643 eid = self._engines[e_uuid]
644 644 e.engine_info['engine_id'] = eid
645 645 return e
646 646
647 647 def _extract_metadata(self, msg):
648 648 header = msg['header']
649 649 parent = msg['parent_header']
650 650 msg_meta = msg['metadata']
651 651 content = msg['content']
652 652 md = {'msg_id' : parent['msg_id'],
653 653 'received' : datetime.now(),
654 654 'engine_uuid' : msg_meta.get('engine', None),
655 655 'follow' : msg_meta.get('follow', []),
656 656 'after' : msg_meta.get('after', []),
657 657 'status' : content['status'],
658 658 }
659 659
660 660 if md['engine_uuid'] is not None:
661 661 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
662 662
663 663 if 'date' in parent:
664 664 md['submitted'] = parent['date']
665 665 if 'started' in msg_meta:
666 666 md['started'] = parse_date(msg_meta['started'])
667 667 if 'date' in header:
668 668 md['completed'] = header['date']
669 669 return md
670 670
671 671 def _register_engine(self, msg):
672 672 """Register a new engine, and update our connection info."""
673 673 content = msg['content']
674 674 eid = content['id']
675 675 d = {eid : content['uuid']}
676 676 self._update_engines(d)
677 677
678 678 def _unregister_engine(self, msg):
679 679 """Unregister an engine that has died."""
680 680 content = msg['content']
681 681 eid = int(content['id'])
682 682 if eid in self._ids:
683 683 self._ids.remove(eid)
684 684 uuid = self._engines.pop(eid)
685 685
686 686 self._handle_stranded_msgs(eid, uuid)
687 687
688 688 if self._task_socket and self._task_scheme == 'pure':
689 689 self._stop_scheduling_tasks()
690 690
691 691 def _handle_stranded_msgs(self, eid, uuid):
692 692 """Handle messages known to be on an engine when the engine unregisters.
693 693
694 694 It is possible that this will fire prematurely - that is, an engine will
695 695 go down after completing a result, and the client will be notified
696 696 of the unregistration and later receive the successful result.
697 697 """
698 698
699 699 outstanding = self._outstanding_dict[uuid]
700 700
701 701 for msg_id in list(outstanding):
702 702 if msg_id in self.results:
703 703 # we already have a result for this msg_id
704 704 continue
705 705 try:
706 706 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
707 707 except:
708 708 content = error.wrap_exception()
709 709 # build a fake message:
710 710 msg = self.session.msg('apply_reply', content=content)
711 711 msg['parent_header']['msg_id'] = msg_id
712 712 msg['metadata']['engine'] = uuid
713 713 self._handle_apply_reply(msg)
714 714
715 715 def _handle_execute_reply(self, msg):
716 716 """Save the reply to an execute_request into our results.
717 717
718 718 execute messages are never actually used. apply is used instead.
719 719 """
720 720
721 721 parent = msg['parent_header']
722 722 msg_id = parent['msg_id']
723 723 if msg_id not in self.outstanding:
724 724 if msg_id in self.history:
725 725 print("got stale result: %s"%msg_id)
726 726 else:
727 727 print("got unknown result: %s"%msg_id)
728 728 else:
729 729 self.outstanding.remove(msg_id)
730 730
731 731 content = msg['content']
732 732 header = msg['header']
733 733
734 734 # construct metadata:
735 735 md = self.metadata[msg_id]
736 736 md.update(self._extract_metadata(msg))
737 737 # is this redundant?
738 738 self.metadata[msg_id] = md
739 739
740 740 e_outstanding = self._outstanding_dict[md['engine_uuid']]
741 741 if msg_id in e_outstanding:
742 742 e_outstanding.remove(msg_id)
743 743
744 744 # construct result:
745 745 if content['status'] == 'ok':
746 746 self.results[msg_id] = ExecuteReply(msg_id, content, md)
747 747 elif content['status'] == 'aborted':
748 748 self.results[msg_id] = error.TaskAborted(msg_id)
749 749 elif content['status'] == 'resubmitted':
750 750 # TODO: handle resubmission
751 751 pass
752 752 else:
753 753 self.results[msg_id] = self._unwrap_exception(content)
754 754
755 755 def _handle_apply_reply(self, msg):
756 756 """Save the reply to an apply_request into our results."""
757 757 parent = msg['parent_header']
758 758 msg_id = parent['msg_id']
759 759 if msg_id not in self.outstanding:
760 760 if msg_id in self.history:
761 761 print("got stale result: %s"%msg_id)
762 762 print(self.results[msg_id])
763 763 print(msg)
764 764 else:
765 765 print("got unknown result: %s"%msg_id)
766 766 else:
767 767 self.outstanding.remove(msg_id)
768 768 content = msg['content']
769 769 header = msg['header']
770 770
771 771 # construct metadata:
772 772 md = self.metadata[msg_id]
773 773 md.update(self._extract_metadata(msg))
774 774 # is this redundant?
775 775 self.metadata[msg_id] = md
776 776
777 777 e_outstanding = self._outstanding_dict[md['engine_uuid']]
778 778 if msg_id in e_outstanding:
779 779 e_outstanding.remove(msg_id)
780 780
781 781 # construct result:
782 782 if content['status'] == 'ok':
783 783 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
784 784 elif content['status'] == 'aborted':
785 785 self.results[msg_id] = error.TaskAborted(msg_id)
786 786 elif content['status'] == 'resubmitted':
787 787 # TODO: handle resubmission
788 788 pass
789 789 else:
790 790 self.results[msg_id] = self._unwrap_exception(content)
791 791
792 792 def _flush_notifications(self):
793 793 """Flush notifications of engine registrations waiting
794 794 in ZMQ queue."""
795 795 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
796 796 while msg is not None:
797 797 if self.debug:
798 798 pprint(msg)
799 799 msg_type = msg['header']['msg_type']
800 800 handler = self._notification_handlers.get(msg_type, None)
801 801 if handler is None:
802 802 raise Exception("Unhandled message type: %s" % msg_type)
803 803 else:
804 804 handler(msg)
805 805 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
806 806
807 807 def _flush_results(self, sock):
808 808 """Flush task or queue results waiting in ZMQ queue."""
809 809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
810 810 while msg is not None:
811 811 if self.debug:
812 812 pprint(msg)
813 813 msg_type = msg['header']['msg_type']
814 814 handler = self._queue_handlers.get(msg_type, None)
815 815 if handler is None:
816 816 raise Exception("Unhandled message type: %s" % msg_type)
817 817 else:
818 818 handler(msg)
819 819 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
820 820
821 821 def _flush_control(self, sock):
822 822 """Flush replies from the control channel waiting
823 823 in the ZMQ queue.
824 824
825 825 Currently: ignore them."""
826 826 if self._ignored_control_replies <= 0:
827 827 return
828 828 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
829 829 while msg is not None:
830 830 self._ignored_control_replies -= 1
831 831 if self.debug:
832 832 pprint(msg)
833 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 834
835 835 def _flush_ignored_control(self):
836 836 """flush ignored control replies"""
837 837 while self._ignored_control_replies > 0:
838 838 self.session.recv(self._control_socket)
839 839 self._ignored_control_replies -= 1
840 840
841 841 def _flush_ignored_hub_replies(self):
842 842 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
843 843 while msg is not None:
844 844 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
845 845
846 846 def _flush_iopub(self, sock):
847 847 """Flush replies from the iopub channel waiting
848 848 in the ZMQ queue.
849 849 """
850 850 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
851 851 while msg is not None:
852 852 if self.debug:
853 853 pprint(msg)
854 854 parent = msg['parent_header']
855 855 # ignore IOPub messages with no parent.
856 856 # Caused by print statements or warnings from before the first execution.
857 857 if not parent:
858 858 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
859 859 continue
860 860 msg_id = parent['msg_id']
861 861 content = msg['content']
862 862 header = msg['header']
863 863 msg_type = msg['header']['msg_type']
864 864
865 865 # init metadata:
866 866 md = self.metadata[msg_id]
867 867
868 868 if msg_type == 'stream':
869 869 name = content['name']
870 870 s = md[name] or ''
871 871 md[name] = s + content['data']
872 872 elif msg_type == 'error':
873 873 md.update({'error' : self._unwrap_exception(content)})
874 874 elif msg_type == 'execute_input':
875 875 md.update({'execute_input' : content['code']})
876 876 elif msg_type == 'display_data':
877 877 md['outputs'].append(content)
878 878 elif msg_type == 'execute_result':
879 879 md['execute_result'] = content
880 880 elif msg_type == 'data_message':
881 881 data, remainder = serialize.unserialize_object(msg['buffers'])
882 882 md['data'].update(data)
883 883 elif msg_type == 'status':
884 884 # idle message comes after all outputs
885 885 if content['execution_state'] == 'idle':
886 886 md['outputs_ready'] = True
887 887 else:
888 888 # unhandled msg_type
889 889 pass
890 890
891 891 # redundant?
892 892 self.metadata[msg_id] = md
893 893
894 894 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
895 895
896 896 #--------------------------------------------------------------------------
897 897 # len, getitem
898 898 #--------------------------------------------------------------------------
899 899
900 900 def __len__(self):
901 901 """len(client) returns # of engines."""
902 902 return len(self.ids)
903 903
904 904 def __getitem__(self, key):
905 905 """index access returns DirectView multiplexer objects
906 906
907 907 Must be int, slice, or list/tuple/xrange of ints"""
908 908 if not isinstance(key, (int, slice, tuple, list, xrange)):
909 909 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
910 910 else:
911 911 return self.direct_view(key)
912 912
913 913 def __iter__(self):
914 914 """Since we define getitem, Client is iterable
915 915
916 916 but unless we also define __iter__, it won't work correctly unless engine IDs
917 917 start at zero and are continuous.
918 918 """
919 919 for eid in self.ids:
920 920 yield self.direct_view(eid)
921 921
922 922 #--------------------------------------------------------------------------
923 923 # Begin public methods
924 924 #--------------------------------------------------------------------------
925 925
926 926 @property
927 927 def ids(self):
928 928 """Always up-to-date ids property."""
929 929 self._flush_notifications()
930 930 # always copy:
931 931 return list(self._ids)
932 932
933 933 def activate(self, targets='all', suffix=''):
934 934 """Create a DirectView and register it with IPython magics
935 935
936 936 Defines the magics `%px, %autopx, %pxresult, %%px`
937 937
938 938 Parameters
939 939 ----------
940 940
941 941 targets: int, list of ints, or 'all'
942 942 The engines on which the view's magics will run
943 943 suffix: str [default: '']
944 944 The suffix, if any, for the magics. This allows you to have
945 945 multiple views associated with parallel magics at the same time.
946 946
947 947 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
948 948 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
949 949 on engine 0.
950 950 """
951 951 view = self.direct_view(targets)
952 952 view.block = True
953 953 view.activate(suffix)
954 954 return view
955 955
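# A minimal sketch of magic registration, assuming a connected Client `rc`
# from a running ipcluster (`rc` is illustrative, not defined above):
dv0 = rc.activate(targets=0, suffix='0')
# In an IPython session, the suffixed magics now exist:
#   %px0 import os; os.getpid()   # runs only on engine 0
#   %pxresult0                    # display the last %px0 result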
956 956 def close(self, linger=None):
957 957 """Close my zmq Sockets
958 958
959 959 If `linger` is given, set the zmq LINGER socket option,
960 960 which allows discarding of unsent messages after that many milliseconds.
961 961 """
962 962 if self._closed:
963 963 return
964 964 self.stop_spin_thread()
965 965 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
966 966 for name in snames:
967 967 socket = getattr(self, name)
968 968 if socket is not None and not socket.closed:
969 969 if linger is not None:
970 970 socket.close(linger=linger)
971 971 else:
972 972 socket.close()
973 973 self._closed = True
974 974
975 975 def _spin_every(self, interval=1):
976 976 """target func for use in spin_thread"""
977 977 while True:
978 978 if self._stop_spinning.is_set():
979 979 return
980 980 time.sleep(interval)
981 981 self.spin()
982 982
983 983 def spin_thread(self, interval=1):
984 984 """call Client.spin() in a background thread on some regular interval
985 985
986 986 This helps ensure that messages don't pile up too much in the zmq queue
987 987 while you are working on other things, or just leaving an idle terminal.
988 988
989 989 It also helps limit potential padding of the `received` timestamp
990 990 on AsyncResult objects, used for timings.
991 991
992 992 Parameters
993 993 ----------
994 994
995 995 interval : float, optional
996 996 The interval on which to spin the client in the background thread
997 997 (simply passed to time.sleep).
998 998
999 999 Notes
1000 1000 -----
1001 1001
1002 1002 For precision timing, you may want to use this method to put a bound
1003 1003 on the jitter (in seconds) in `received` timestamps used
1004 1004 in AsyncResult.wall_time.
1005 1005
1006 1006 """
1007 1007 if self._spin_thread is not None:
1008 1008 self.stop_spin_thread()
1009 1009 self._stop_spinning.clear()
1010 1010 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1011 1011 self._spin_thread.daemon = True
1012 1012 self._spin_thread.start()
1013 1013
1014 1014 def stop_spin_thread(self):
1015 1015 """stop background spin_thread, if any"""
1016 1016 if self._spin_thread is not None:
1017 1017 self._stop_spinning.set()
1018 1018 self._spin_thread.join()
1019 1019 self._spin_thread = None
1020 1020
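# A sketch of background spinning to bound `received`-timestamp jitter,
# assuming a connected Client `rc` and a DirectView `dv` (both illustrative):
rc.spin_thread(interval=0.1)          # flush queues every 100 ms
ar = dv.apply_async(sum, range(10))   # submit work while spinning continues
ar.get()                              # `received` is now accurate to ~0.1 s
rc.stop_spin_thread()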
1021 1021 def spin(self):
1022 1022 """Flush any registration notifications and execution results
1023 1023 waiting in the ZMQ queue.
1024 1024 """
1025 1025 if self._notification_socket:
1026 1026 self._flush_notifications()
1027 1027 if self._iopub_socket:
1028 1028 self._flush_iopub(self._iopub_socket)
1029 1029 if self._mux_socket:
1030 1030 self._flush_results(self._mux_socket)
1031 1031 if self._task_socket:
1032 1032 self._flush_results(self._task_socket)
1033 1033 if self._control_socket:
1034 1034 self._flush_control(self._control_socket)
1035 1035 if self._query_socket:
1036 1036 self._flush_ignored_hub_replies()
1037 1037
1038 1038 def wait(self, jobs=None, timeout=-1):
1039 1039 """waits on one or more `jobs`, for up to `timeout` seconds.
1040 1040
1041 1041 Parameters
1042 1042 ----------
1043 1043
1044 1044 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1045 1045 ints are indices to self.history
1046 1046 strs are msg_ids
1047 1047 default: wait on all outstanding messages
1048 1048 timeout : float
1049 1049 a time in seconds, after which to give up.
1050 1050 default is -1, which means no timeout
1051 1051
1052 1052 Returns
1053 1053 -------
1054 1054
1055 1055 True : when all msg_ids are done
1056 1056 False : timeout reached, some msg_ids still outstanding
1057 1057 """
1058 1058 tic = time.time()
1059 1059 if jobs is None:
1060 1060 theids = self.outstanding
1061 1061 else:
1062 1062 if isinstance(jobs, string_types + (int, AsyncResult)):
1063 1063 jobs = [jobs]
1064 1064 theids = set()
1065 1065 for job in jobs:
1066 1066 if isinstance(job, int):
1067 1067 # index access
1068 1068 job = self.history[job]
1069 1069 elif isinstance(job, AsyncResult):
1070 1070 theids.update(job.msg_ids)
1071 1071 continue
1072 1072 theids.add(job)
1073 1073 if not theids.intersection(self.outstanding):
1074 1074 return True
1075 1075 self.spin()
1076 1076 while theids.intersection(self.outstanding):
1077 1077 if timeout >= 0 and ( time.time()-tic ) > timeout:
1078 1078 break
1079 1079 time.sleep(1e-3)
1080 1080 self.spin()
1081 1081 return len(theids.intersection(self.outstanding)) == 0
1082 1082
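# A sketch of wait()'s flexible `jobs` argument, assuming a connected Client
# `rc` and an AsyncResult `ar` (both illustrative): ints index self.history,
# strs are msg_ids, and AsyncResults contribute their msg_ids.
done = rc.wait([ar, -1], timeout=5)   # wait on `ar` plus the last submission
if not done:
    print("still outstanding:", rc.outstanding)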
1083 1083 #--------------------------------------------------------------------------
1084 1084 # Control methods
1085 1085 #--------------------------------------------------------------------------
1086 1086
1087 1087 @spin_first
1088 1088 def clear(self, targets=None, block=None):
1089 1089 """Clear the namespace in target(s)."""
1090 1090 block = self.block if block is None else block
1091 1091 targets = self._build_targets(targets)[0]
1092 1092 for t in targets:
1093 1093 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1094 1094 error = False
1095 1095 if block:
1096 1096 self._flush_ignored_control()
1097 1097 for i in range(len(targets)):
1098 1098 idents,msg = self.session.recv(self._control_socket,0)
1099 1099 if self.debug:
1100 1100 pprint(msg)
1101 1101 if msg['content']['status'] != 'ok':
1102 1102 error = self._unwrap_exception(msg['content'])
1103 1103 else:
1104 1104 self._ignored_control_replies += len(targets)
1105 1105 if error:
1106 1106 raise error
1107 1107
1108 1108
1109 1109 @spin_first
1110 1110 def abort(self, jobs=None, targets=None, block=None):
1111 1111 """Abort specific jobs from the execution queues of target(s).
1112 1112
1113 1113 This is a mechanism to prevent jobs that have already been submitted
1114 1114 from executing.
1115 1115
1116 1116 Parameters
1117 1117 ----------
1118 1118
1119 1119 jobs : msg_id, list of msg_ids, or AsyncResult
1120 1120 The jobs to be aborted
1121 1121
1122 1122 If unspecified/None: abort all outstanding jobs.
1123 1123
1124 1124 """
1125 1125 block = self.block if block is None else block
1126 1126 jobs = jobs if jobs is not None else list(self.outstanding)
1127 1127 targets = self._build_targets(targets)[0]
1128 1128
1129 1129 msg_ids = []
1130 1130 if isinstance(jobs, string_types + (AsyncResult,)):
1131 1131 jobs = [jobs]
1132 1132 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1133 1133 if bad_ids:
1134 1134 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1135 1135 for j in jobs:
1136 1136 if isinstance(j, AsyncResult):
1137 1137 msg_ids.extend(j.msg_ids)
1138 1138 else:
1139 1139 msg_ids.append(j)
1140 1140 content = dict(msg_ids=msg_ids)
1141 1141 for t in targets:
1142 1142 self.session.send(self._control_socket, 'abort_request',
1143 1143 content=content, ident=t)
1144 1144 error = False
1145 1145 if block:
1146 1146 self._flush_ignored_control()
1147 1147 for i in range(len(targets)):
1148 1148 idents,msg = self.session.recv(self._control_socket,0)
1149 1149 if self.debug:
1150 1150 pprint(msg)
1151 1151 if msg['content']['status'] != 'ok':
1152 1152 error = self._unwrap_exception(msg['content'])
1153 1153 else:
1154 1154 self._ignored_control_replies += len(targets)
1155 1155 if error:
1156 1156 raise error
1157 1157
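# A sketch of aborting queued work, assuming a connected Client `rc` and an
# AsyncResult `ar` for jobs that have not started executing (illustrative):
rc.abort(jobs=ar, block=True)   # abort just these msg_ids
rc.abort()                      # or abort everything still outstanding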
1158 1158 @spin_first
1159 1159 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1160 1160 """Terminates one or more engine processes, optionally including the hub.
1161 1161
1162 1162 Parameters
1163 1163 ----------
1164 1164
1165 1165 targets: list of ints or 'all' [default: all]
1166 1166 Which engines to shutdown.
1167 1167 hub: bool [default: False]
1168 1168 Whether to include the Hub. hub=True implies targets='all'.
1169 1169 block: bool [default: self.block]
1170 1170 Whether to wait for clean shutdown replies or not.
1171 1171 restart: bool [default: False]
1172 1172 NOT IMPLEMENTED
1173 1173 whether to restart engines after shutting them down.
1174 1174 """
1175 1175 from IPython.parallel.error import NoEnginesRegistered
1176 1176 if restart:
1177 1177 raise NotImplementedError("Engine restart is not yet implemented")
1178 1178
1179 1179 block = self.block if block is None else block
1180 1180 if hub:
1181 1181 targets = 'all'
1182 1182 try:
1183 1183 targets = self._build_targets(targets)[0]
1184 1184 except NoEnginesRegistered:
1185 1185 targets = []
1186 1186 for t in targets:
1187 1187 self.session.send(self._control_socket, 'shutdown_request',
1188 1188 content={'restart':restart},ident=t)
1189 1189 error = False
1190 1190 if block or hub:
1191 1191 self._flush_ignored_control()
1192 1192 for i in range(len(targets)):
1193 1193 idents,msg = self.session.recv(self._control_socket, 0)
1194 1194 if self.debug:
1195 1195 pprint(msg)
1196 1196 if msg['content']['status'] != 'ok':
1197 1197 error = self._unwrap_exception(msg['content'])
1198 1198 else:
1199 1199 self._ignored_control_replies += len(targets)
1200 1200
1201 1201 if hub:
1202 1202 time.sleep(0.25)
1203 1203 self.session.send(self._query_socket, 'shutdown_request')
1204 1204 idents,msg = self.session.recv(self._query_socket, 0)
1205 1205 if self.debug:
1206 1206 pprint(msg)
1207 1207 if msg['content']['status'] != 'ok':
1208 1208 error = self._unwrap_exception(msg['content'])
1209 1209
1210 1210 if error:
1211 1211 raise error
1212 1212
1213 1213 #--------------------------------------------------------------------------
1214 1214 # Execution related methods
1215 1215 #--------------------------------------------------------------------------
1216 1216
1217 1217 def _maybe_raise(self, result):
1218 1218 """wrapper for maybe raising an exception if apply failed."""
1219 1219 if isinstance(result, error.RemoteError):
1220 1220 raise result
1221 1221
1222 1222 return result
1223 1223
1224 1224 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1225 1225 ident=None):
1226 1226 """construct and send an apply message via a socket.
1227 1227
1228 1228 This is the principal method with which all engine execution is performed by views.
1229 1229 """
1230 1230
1231 1231 if self._closed:
1232 1232 raise RuntimeError("Client cannot be used after its sockets have been closed")
1233 1233
1234 1234 # defaults:
1235 1235 args = args if args is not None else []
1236 1236 kwargs = kwargs if kwargs is not None else {}
1237 1237 metadata = metadata if metadata is not None else {}
1238 1238
1239 1239 # validate arguments
1240 1240 if not callable(f) and not isinstance(f, Reference):
1241 1241 raise TypeError("f must be callable, not %s"%type(f))
1242 1242 if not isinstance(args, (tuple, list)):
1243 1243 raise TypeError("args must be tuple or list, not %s"%type(args))
1244 1244 if not isinstance(kwargs, dict):
1245 1245 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1246 1246 if not isinstance(metadata, dict):
1247 1247 raise TypeError("metadata must be dict, not %s"%type(metadata))
1248 1248
1249 1249 bufs = serialize.pack_apply_message(f, args, kwargs,
1250 1250 buffer_threshold=self.session.buffer_threshold,
1251 1251 item_threshold=self.session.item_threshold,
1252 1252 )
1253 1253
1254 1254 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1255 1255 metadata=metadata, track=track)
1256 1256
1257 1257 msg_id = msg['header']['msg_id']
1258 1258 self.outstanding.add(msg_id)
1259 1259 if ident:
1260 1260 # possibly routed to a specific engine
1261 1261 if isinstance(ident, list):
1262 1262 ident = ident[-1]
1263 1263 if ident in self._engines.values():
1264 1264 # save for later, in case of engine death
1265 1265 self._outstanding_dict[ident].add(msg_id)
1266 1266 self.history.append(msg_id)
1267 1267 self.metadata[msg_id]['submitted'] = datetime.now()
1268 1268
1269 1269 return msg
1270 1270
1271 1271 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1272 1272 """construct and send an execute request via a socket.
1273 1273
1274 1274 """
1275 1275
1276 1276 if self._closed:
1277 1277 raise RuntimeError("Client cannot be used after its sockets have been closed")
1278 1278
1279 1279 # defaults:
1280 1280 metadata = metadata if metadata is not None else {}
1281 1281
1282 1282 # validate arguments
1283 1283 if not isinstance(code, string_types):
1284 1284 raise TypeError("code must be text, not %s" % type(code))
1285 1285 if not isinstance(metadata, dict):
1286 1286 raise TypeError("metadata must be dict, not %s" % type(metadata))
1287 1287
1288 1288 content = dict(code=code, silent=bool(silent), user_expressions={})
1289 1289
1290 1290
1291 1291 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1292 1292 metadata=metadata)
1293 1293
1294 1294 msg_id = msg['header']['msg_id']
1295 1295 self.outstanding.add(msg_id)
1296 1296 if ident:
1297 1297 # possibly routed to a specific engine
1298 1298 if isinstance(ident, list):
1299 1299 ident = ident[-1]
1300 1300 if ident in self._engines.values():
1301 1301 # save for later, in case of engine death
1302 1302 self._outstanding_dict[ident].add(msg_id)
1303 1303 self.history.append(msg_id)
1304 1304 self.metadata[msg_id]['submitted'] = datetime.now()
1305 1305
1306 1306 return msg
1307 1307
1308 1308 #--------------------------------------------------------------------------
1309 1309 # construct a View object
1310 1310 #--------------------------------------------------------------------------
1311 1311
1312 1312 def load_balanced_view(self, targets=None):
1313 1313 """construct a LoadBalancedView object.
1314 1314
1315 1315 If no arguments are specified, create a LoadBalancedView
1316 1316 using all engines.
1317 1317
1318 1318 Parameters
1319 1319 ----------
1320 1320
1321 1321 targets: list,slice,int,etc. [default: use all engines]
1322 1322 The subset of engines across which to load-balance
1323 1323 """
1324 1324 if targets == 'all':
1325 1325 targets = None
1326 1326 if targets is not None:
1327 1327 targets = self._build_targets(targets)[1]
1328 1328 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1329 1329
1330 1330 def direct_view(self, targets='all'):
1331 1331 """construct a DirectView object.
1332 1332
1333 1333 If no targets are specified, create a DirectView using all engines.
1334 1334
1335 1335 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1336 1336 evaluate the target engines at each execution, whereas rc[:] will connect to
1337 1337 all *current* engines, and that list will not change.
1338 1338
1339 1339 That is, 'all' will always use all engines, whereas rc[:] will not use
1340 1340 engines added after the DirectView is constructed.
1341 1341
1342 1342 Parameters
1343 1343 ----------
1344 1344
1345 1345 targets: list,slice,int,etc. [default: use all engines]
1346 1346 The engines to use for the View
1347 1347 """
1348 1348 single = isinstance(targets, int)
1349 1349 # allow 'all' to be lazily evaluated at each execution
1350 1350 if targets != 'all':
1351 1351 targets = self._build_targets(targets)[1]
1352 1352 if single:
1353 1353 targets = targets[0]
1354 1354 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1355 1355
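# A sketch of the 'all'-vs-slice distinction, assuming a connected Client `rc`:
dv_live = rc.direct_view('all')   # re-evaluates the engine list per execution
dv_now = rc[:]                    # pinned to the engines registered right now
lbv = rc.load_balanced_view()     # load-balances over all current engines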
1356 1356 #--------------------------------------------------------------------------
1357 1357 # Query methods
1358 1358 #--------------------------------------------------------------------------
1359 1359
1360 1360 @spin_first
1361 def get_result(self, indices_or_msg_ids=None, block=None):
1361 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1362 1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1363 1363
1364 1364 If the client already has the results, no request to the Hub will be made.
1365 1365
1366 1366 This is a convenient way to construct AsyncResult objects, which are wrappers
1367 1367 that include metadata about execution, and allow for awaiting results that
1368 1368 were not submitted by this Client.
1369 1369
1370 1370 It can also be a convenient way to retrieve the metadata associated with
1371 1371 blocking execution, since it always returns an AsyncResult, which exposes that metadata.
1372 1372
1373 1373 Examples
1374 1374 --------
1375 1375 ::
1376 1376
1377 1377 In [10]: ar = client.get_result(msg_id)  # msg_id from an earlier submission
1378 1378
1379 1379 Parameters
1380 1380 ----------
1381 1381
1382 1382 indices_or_msg_ids : integer history index, str msg_id, or list of either
1383 1383 The history indices or msg_ids of the results to be retrieved
1384 1384
1385 1385 block : bool
1386 1386 Whether to wait for the result to be done
1387 owner : bool [default: True]
1388 Whether this AsyncResult should own the result.
1389 If so, calling `ar.get()` will remove data from the
1390 client's result and metadata cache.
1391 There should only be one owner of any given msg_id.
1387 1392
1388 1393 Returns
1389 1394 -------
1390 1395
1391 1396 AsyncResult
1392 1397 A single AsyncResult object will always be returned.
1393 1398
1394 1399 AsyncHubResult
1395 1400 A subclass of AsyncResult that retrieves results from the Hub
1396 1401
1397 1402 """
1398 1403 block = self.block if block is None else block
1399 1404 if indices_or_msg_ids is None:
1400 1405 indices_or_msg_ids = -1
1401 1406
1402 1407 single_result = False
1403 1408 if not isinstance(indices_or_msg_ids, (list,tuple)):
1404 1409 indices_or_msg_ids = [indices_or_msg_ids]
1405 1410 single_result = True
1406 1411
1407 1412 theids = []
1408 1413 for id in indices_or_msg_ids:
1409 1414 if isinstance(id, int):
1410 1415 id = self.history[id]
1411 1416 if not isinstance(id, string_types):
1412 1417 raise TypeError("indices must be str or int, not %r"%id)
1413 1418 theids.append(id)
1414 1419
1415 1420 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1416 1421 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1417 1422
1418 1423 # given single msg_id initially, get_result should get the result itself,
1419 1424 # not a length-one list
1420 1425 if single_result:
1421 1426 theids = theids[0]
1422 1427
1423 1428 if remote_ids:
1424 ar = AsyncHubResult(self, msg_ids=theids)
1429 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1425 1430 else:
1426 ar = AsyncResult(self, msg_ids=theids)
1431 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1427 1432
1428 1433 if block:
1429 1434 ar.wait()
1430 1435
1431 1436 return ar
1432 1437
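# A sketch of the new `owner` flag, assuming a connected Client `rc` and a
# `msg_id` from an earlier submission (both illustrative). An owning
# AsyncResult evicts its entries from the client's caches on get(), so pass
# owner=False to peek at a result without disturbing the cache:
ar = rc.get_result(msg_id, owner=False)
ar.get()                         # local result/metadata caches stay intact
ar = rc.get_result(msg_id)       # owner=True is the default here
ar.get()                         # cached result/metadata are now reclaimed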
1433 1438 @spin_first
1434 1439 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1435 1440 """Resubmit one or more tasks.
1436 1441
1437 1442 in-flight tasks may not be resubmitted.
1438 1443
1439 1444 Parameters
1440 1445 ----------
1441 1446
1442 1447 indices_or_msg_ids : integer history index, str msg_id, or list of either
1443 1448 The history indices or msg_ids of the results to be retrieved
1444 1449
1445 1450 block : bool
1446 1451 Whether to wait for the result to be done
1447 1452
1448 1453 Returns
1449 1454 -------
1450 1455
1451 1456 AsyncHubResult
1452 1457 A subclass of AsyncResult that retrieves results from the Hub
1453 1458
1454 1459 """
1455 1460 block = self.block if block is None else block
1456 1461 if indices_or_msg_ids is None:
1457 1462 indices_or_msg_ids = -1
1458 1463
1459 1464 if not isinstance(indices_or_msg_ids, (list,tuple)):
1460 1465 indices_or_msg_ids = [indices_or_msg_ids]
1461 1466
1462 1467 theids = []
1463 1468 for id in indices_or_msg_ids:
1464 1469 if isinstance(id, int):
1465 1470 id = self.history[id]
1466 1471 if not isinstance(id, string_types):
1467 1472 raise TypeError("indices must be str or int, not %r"%id)
1468 1473 theids.append(id)
1469 1474
1470 1475 content = dict(msg_ids = theids)
1471 1476
1472 1477 self.session.send(self._query_socket, 'resubmit_request', content)
1473 1478
1474 1479 zmq.select([self._query_socket], [], [])
1475 1480 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1476 1481 if self.debug:
1477 1482 pprint(msg)
1478 1483 content = msg['content']
1479 1484 if content['status'] != 'ok':
1480 1485 raise self._unwrap_exception(content)
1481 1486 mapping = content['resubmitted']
1482 1487 new_ids = [ mapping[msg_id] for msg_id in theids ]
1483 1488
1484 1489 ar = AsyncHubResult(self, msg_ids=new_ids)
1485 1490
1486 1491 if block:
1487 1492 ar.wait()
1488 1493
1489 1494 return ar
1490 1495
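# A sketch of resubmission via the Hub, assuming a connected Client `rc` and
# a finished AsyncResult `ar` (illustrative); in-flight tasks are rejected:
ar2 = rc.resubmit(ar.msg_ids, block=False)
ar2.get()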
1491 1496 @spin_first
1492 1497 def result_status(self, msg_ids, status_only=True):
1493 1498 """Check on the status of the result(s) of the apply request with `msg_ids`.
1494 1499
1495 1500 If status_only is False, then the actual results will be retrieved, else
1496 1501 only the status of the results will be checked.
1497 1502
1498 1503 Parameters
1499 1504 ----------
1500 1505
1501 1506 msg_ids : list of msg_ids
1502 1507 if int:
1503 1508 Passed as index to self.history for convenience.
1504 1509 status_only : bool (default: True)
1505 1510 if False:
1506 1511 Retrieve the actual results of completed tasks.
1507 1512
1508 1513 Returns
1509 1514 -------
1510 1515
1511 1516 results : dict
1512 1517 There will always be the keys 'pending' and 'completed', which will
1513 1518 be lists of msg_ids that are incomplete or complete. If `status_only`
1514 1519 is False, then completed results will be keyed by their `msg_id`.
1515 1520 """
1516 1521 if not isinstance(msg_ids, (list,tuple)):
1517 1522 msg_ids = [msg_ids]
1518 1523
1519 1524 theids = []
1520 1525 for msg_id in msg_ids:
1521 1526 if isinstance(msg_id, int):
1522 1527 msg_id = self.history[msg_id]
1523 1528 if not isinstance(msg_id, string_types):
1524 1529 raise TypeError("msg_ids must be str, not %r"%msg_id)
1525 1530 theids.append(msg_id)
1526 1531
1527 1532 completed = []
1528 1533 local_results = {}
1529 1534
1530 1535 # comment this block out to temporarily disable local shortcut:
1531 1536 for msg_id in theids:
1532 1537 if msg_id in self.results:
1533 1538 completed.append(msg_id)
1534 1539 local_results[msg_id] = self.results[msg_id]
1535 1540 theids.remove(msg_id)
1536 1541
1537 1542 if theids: # some not locally cached
1538 1543 content = dict(msg_ids=theids, status_only=status_only)
1539 1544 msg = self.session.send(self._query_socket, "result_request", content=content)
1540 1545 zmq.select([self._query_socket], [], [])
1541 1546 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1542 1547 if self.debug:
1543 1548 pprint(msg)
1544 1549 content = msg['content']
1545 1550 if content['status'] != 'ok':
1546 1551 raise self._unwrap_exception(content)
1547 1552 buffers = msg['buffers']
1548 1553 else:
1549 1554 content = dict(completed=[],pending=[])
1550 1555
1551 1556 content['completed'].extend(completed)
1552 1557
1553 1558 if status_only:
1554 1559 return content
1555 1560
1556 1561 failures = []
1557 1562 # load cached results into result:
1558 1563 content.update(local_results)
1559 1564
1560 1565 # update cache with results:
1561 1566 for msg_id in sorted(theids):
1562 1567 if msg_id in content['completed']:
1563 1568 rec = content[msg_id]
1564 1569 parent = extract_dates(rec['header'])
1565 1570 header = extract_dates(rec['result_header'])
1566 1571 rcontent = rec['result_content']
1567 1572 iodict = rec['io']
1568 1573 if isinstance(rcontent, str):
1569 1574 rcontent = self.session.unpack(rcontent)
1570 1575
1571 1576 md = self.metadata[msg_id]
1572 1577 md_msg = dict(
1573 1578 content=rcontent,
1574 1579 parent_header=parent,
1575 1580 header=header,
1576 1581 metadata=rec['result_metadata'],
1577 1582 )
1578 1583 md.update(self._extract_metadata(md_msg))
1579 1584 if rec.get('received'):
1580 1585 md['received'] = parse_date(rec['received'])
1581 1586 md.update(iodict)
1582 1587
1583 1588 if rcontent['status'] == 'ok':
1584 1589 if header['msg_type'] == 'apply_reply':
1585 1590 res,buffers = serialize.unserialize_object(buffers)
1586 1591 elif header['msg_type'] == 'execute_reply':
1587 1592 res = ExecuteReply(msg_id, rcontent, md)
1588 1593 else:
1589 1594 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1590 1595 else:
1591 1596 res = self._unwrap_exception(rcontent)
1592 1597 failures.append(res)
1593 1598
1594 1599 self.results[msg_id] = res
1595 1600 content[msg_id] = res
1596 1601
1597 1602 if len(theids) == 1 and failures:
1598 1603 raise failures[0]
1599 1604
1600 1605 error.collect_exceptions(failures, "result_status")
1601 1606 return content
1602 1607
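# A sketch of a status-only check, assuming a connected Client `rc` with
# some submission history (illustrative):
status = rc.result_status(rc.history[-3:], status_only=True)
print(status['pending'])      # msg_ids still running
print(status['completed'])    # msg_ids with results available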
1603 1608 @spin_first
1604 1609 def queue_status(self, targets='all', verbose=False):
1605 1610 """Fetch the status of engine queues.
1606 1611
1607 1612 Parameters
1608 1613 ----------
1609 1614
1610 1615 targets : int/str/list of ints/strs
1611 1616 the engines whose states are to be queried.
1612 1617 default : all
1613 1618 verbose : bool
1614 1619 Whether to return lengths only, or lists of ids for each element
1615 1620 """
1616 1621 if targets == 'all':
1617 1622 # allow 'all' to be evaluated on the engine
1618 1623 engine_ids = None
1619 1624 else:
1620 1625 engine_ids = self._build_targets(targets)[1]
1621 1626 content = dict(targets=engine_ids, verbose=verbose)
1622 1627 self.session.send(self._query_socket, "queue_request", content=content)
1623 1628 idents,msg = self.session.recv(self._query_socket, 0)
1624 1629 if self.debug:
1625 1630 pprint(msg)
1626 1631 content = msg['content']
1627 1632 status = content.pop('status')
1628 1633 if status != 'ok':
1629 1634 raise self._unwrap_exception(content)
1630 1635 content = rekey(content)
1631 1636 if isinstance(targets, int):
1632 1637 return content[targets]
1633 1638 else:
1634 1639 return content
1635 1640
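# A sketch, assuming a connected Client `rc`: lengths by default, full
# msg_id lists with verbose=True.
rc.queue_status()                         # all engines, counts only
rc.queue_status(targets=0, verbose=True)  # engine 0, listing msg_ids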
1636 1641 def _build_msgids_from_target(self, targets=None):
1637 1642 """Build a list of msg_ids from the list of engine targets"""
1638 1643 if not targets: # needed as _build_targets otherwise uses all engines
1639 1644 return []
1640 1645 target_ids = self._build_targets(targets)[0]
1641 1646 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1642 1647
1643 1648 def _build_msgids_from_jobs(self, jobs=None):
1644 1649 """Build a list of msg_ids from "jobs" """
1645 1650 if not jobs:
1646 1651 return []
1647 1652 msg_ids = []
1648 1653 if isinstance(jobs, string_types + (AsyncResult,)):
1649 1654 jobs = [jobs]
1650 1655 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1651 1656 if bad_ids:
1652 1657 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1653 1658 for j in jobs:
1654 1659 if isinstance(j, AsyncResult):
1655 1660 msg_ids.extend(j.msg_ids)
1656 1661 else:
1657 1662 msg_ids.append(j)
1658 1663 return msg_ids
1659 1664
1660 1665 def purge_local_results(self, jobs=[], targets=[]):
1661 1666 """Clears the client caches of results and their metadata.
1662 1667
1663 1668 Individual results can be purged by msg_id, or the entire
1664 1669 history of specific targets can be purged.
1665 1670
1666 1671 Use `purge_local_results('all')` to scrub everything from the Client's
1667 1672 results and metadata caches.
1668 1673
1669 1674 After this call all `AsyncResults` are invalid and should be discarded.
1670 1675
1671 1676 If you must "reget" the results, you can still do so by using
1672 1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1673 1678 redownload the results from the hub if they are still available
1674 1679 (i.e. `client.purge_hub_results(...)` has not been called).
1675 1680
1676 1681 Parameters
1677 1682 ----------
1678 1683
1679 1684 jobs : str or list of str or AsyncResult objects
1680 1685 the msg_ids whose results should be purged.
1681 1686 targets : int/list of ints
1682 1687 The engines, by integer ID, whose entire result histories are to be purged.
1683 1688
1684 1689 Raises
1685 1690 ------
1686 1691
1687 1692 RuntimeError : if any of the tasks to be purged are still outstanding.
1688 1693
1689 1694 """
1690 1695 if not targets and not jobs:
1691 1696 raise ValueError("Must specify at least one of `targets` and `jobs`")
1692 1697
1693 1698 if jobs == 'all':
1694 1699 if self.outstanding:
1695 1700 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1696 1701 self.results.clear()
1697 1702 self.metadata.clear()
1698 1703 else:
1699 1704 msg_ids = set()
1700 1705 msg_ids.update(self._build_msgids_from_target(targets))
1701 1706 msg_ids.update(self._build_msgids_from_jobs(jobs))
1702 1707 still_outstanding = self.outstanding.intersection(msg_ids)
1703 1708 if still_outstanding:
1704 1709 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1705 1710 for mid in msg_ids:
1706 self.results.pop(mid)
1707 self.metadata.pop(mid)
1711 self.results.pop(mid, None)
1712 self.metadata.pop(mid, None)
1708 1713
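# A sketch of local cache maintenance, assuming a connected Client `rc` and
# a finished AsyncResult `ar` (illustrative):
rc.purge_local_results(ar)       # drop only these cached results
rc.purge_local_results('all')    # or scrub the entire local cache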
1709 1714
1710 1715 @spin_first
1711 1716 def purge_hub_results(self, jobs=[], targets=[]):
1712 1717 """Tell the Hub to forget results.
1713 1718
1714 1719 Individual results can be purged by msg_id, or the entire
1715 1720 history of specific targets can be purged.
1716 1721
1717 1722 Use `purge_results('all')` to scrub everything from the Hub's db.
1718 1723
1719 1724 Parameters
1720 1725 ----------
1721 1726
1722 1727 jobs : str or list of str or AsyncResult objects
1723 1728 the msg_ids whose results should be forgotten.
1724 1729 targets : int/str/list of ints/strs
1725 1730 The targets, by int_id, whose entire history is to be purged.
1726 1731
1727 1732 default : None
1728 1733 """
1729 1734 if not targets and not jobs:
1730 1735 raise ValueError("Must specify at least one of `targets` and `jobs`")
1731 1736 if targets:
1732 1737 targets = self._build_targets(targets)[1]
1733 1738
1734 1739 # construct msg_ids from jobs
1735 1740 if jobs == 'all':
1736 1741 msg_ids = jobs
1737 1742 else:
1738 1743 msg_ids = self._build_msgids_from_jobs(jobs)
1739 1744
1740 1745 content = dict(engine_ids=targets, msg_ids=msg_ids)
1741 1746 self.session.send(self._query_socket, "purge_request", content=content)
1742 1747 idents, msg = self.session.recv(self._query_socket, 0)
1743 1748 if self.debug:
1744 1749 pprint(msg)
1745 1750 content = msg['content']
1746 1751 if content['status'] != 'ok':
1747 1752 raise self._unwrap_exception(content)
1748 1753
1749 1754 def purge_results(self, jobs=[], targets=[]):
1750 1755 """Clears the cached results from both the hub and the local client
1751 1756
1752 1757 Individual results can be purged by msg_id, or the entire
1753 1758 history of specific targets can be purged.
1754 1759
1755 1760 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1756 1761 the Client's db.
1757 1762
1758 1763 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1759 1764 the same arguments.
1760 1765
1761 1766 Parameters
1762 1767 ----------
1763 1768
1764 1769 jobs : str or list of str or AsyncResult objects
1765 1770 the msg_ids whose results should be forgotten.
1766 1771 targets : int/str/list of ints/strs
1767 1772 The targets, by int_id, whose entire history is to be purged.
1768 1773
1769 1774 default : None
1770 1775 """
1771 1776 self.purge_local_results(jobs=jobs, targets=targets)
1772 1777 self.purge_hub_results(jobs=jobs, targets=targets)
1773 1778
1774 1779 def purge_everything(self):
1775 1780 """Clears all content from previous Tasks from both the hub and the local client
1776 1781
1777 1782 In addition to calling `purge_results("all")` it also deletes the history and
1778 1783 other bookkeeping lists.
1779 1784 """
1780 1785 self.purge_results("all")
1781 1786 self.history = []
1782 1787 self.session.digest_history.clear()
1783 1788
1784 1789 @spin_first
1785 1790 def hub_history(self):
1786 1791 """Get the Hub's history
1787 1792
1788 1793 Just like the Client, the Hub has a history, which is a list of msg_ids.
1789 1794 This will contain the history of all clients, and, depending on configuration,
1790 1795 may contain history across multiple cluster sessions.
1791 1796
1792 1797 Any msg_id returned here is a valid argument to `get_result`.
1793 1798
1794 1799 Returns
1795 1800 -------
1796 1801
1797 1802 msg_ids : list of strs
1798 1803 list of all msg_ids, ordered by task submission time.
1799 1804 """
1800 1805
1801 1806 self.session.send(self._query_socket, "history_request", content={})
1802 1807 idents, msg = self.session.recv(self._query_socket, 0)
1803 1808
1804 1809 if self.debug:
1805 1810 pprint(msg)
1806 1811 content = msg['content']
1807 1812 if content['status'] != 'ok':
1808 1813 raise self._unwrap_exception(content)
1809 1814 else:
1810 1815 return content['history']
1811 1816
1812 1817 @spin_first
1813 1818 def db_query(self, query, keys=None):
1814 1819 """Query the Hub's TaskRecord database
1815 1820
1816 1821 This will return a list of task record dicts that match `query`
1817 1822
1818 1823 Parameters
1819 1824 ----------
1820 1825
1821 1826 query : mongodb query dict
1822 1827 The search dict. See mongodb query docs for details.
1823 1828 keys : list of strs [optional]
1824 1829 The subset of keys to be returned. The default is to fetch everything but buffers.
1825 1830 'msg_id' will *always* be included.
1826 1831 """
1827 1832 if isinstance(keys, string_types):
1828 1833 keys = [keys]
1829 1834 content = dict(query=query, keys=keys)
1830 1835 self.session.send(self._query_socket, "db_request", content=content)
1831 1836 idents, msg = self.session.recv(self._query_socket, 0)
1832 1837 if self.debug:
1833 1838 pprint(msg)
1834 1839 content = msg['content']
1835 1840 if content['status'] != 'ok':
1836 1841 raise self._unwrap_exception(content)
1837 1842
1838 1843 records = content['records']
1839 1844
1840 1845 buffer_lens = content['buffer_lens']
1841 1846 result_buffer_lens = content['result_buffer_lens']
1842 1847 buffers = msg['buffers']
1843 1848 has_bufs = buffer_lens is not None
1844 1849 has_rbufs = result_buffer_lens is not None
1845 1850 for i,rec in enumerate(records):
1846 1851 # unpack datetime objects
1847 1852 for hkey in ('header', 'result_header'):
1848 1853 if hkey in rec:
1849 1854 rec[hkey] = extract_dates(rec[hkey])
1850 1855 for dtkey in ('submitted', 'started', 'completed', 'received'):
1851 1856 if dtkey in rec:
1852 1857 rec[dtkey] = parse_date(rec[dtkey])
1853 1858 # relink buffers
1854 1859 if has_bufs:
1855 1860 blen = buffer_lens[i]
1856 1861 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1857 1862 if has_rbufs:
1858 1863 blen = result_buffer_lens[i]
1859 1864 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1860 1865
1861 1866 return records
1862 1867
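# A sketch of a TaskRecord query using MongoDB-style operators, assuming a
# connected Client `rc` and a Hub db backend that supports queries:
records = rc.db_query({'completed': {'$ne': None}},
                      keys=['msg_id', 'completed'])
for rec in records:
    print(rec['msg_id'], rec['completed'])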
1863 1868 __all__ = [ 'Client' ]
@@ -1,1131 +1,1125 b''
1 """Views of remote engines.
1 """Views of remote engines."""
2 2
3 Authors:
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 5
5 * Min RK
6 """
7 6 from __future__ import print_function
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
18 7
19 8 import imp
20 9 import sys
21 10 import warnings
22 11 from contextlib import contextmanager
23 12 from types import ModuleType
24 13
25 14 import zmq
26 15
27 16 from IPython.testing.skipdoctest import skip_doctest
28 17 from IPython.utils import pickleutil
29 18 from IPython.utils.traitlets import (
30 19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
31 20 )
32 21 from IPython.external.decorator import decorator
33 22
34 23 from IPython.parallel import util
35 24 from IPython.parallel.controller.dependency import Dependency, dependent
36 25 from IPython.utils.py3compat import string_types, iteritems, PY3
37 26
38 27 from . import map as Map
39 28 from .asyncresult import AsyncResult, AsyncMapResult
40 29 from .remotefunction import ParallelFunction, parallel, remote, getname
41 30
42 31 #-----------------------------------------------------------------------------
43 32 # Decorators
44 33 #-----------------------------------------------------------------------------
45 34
46 35 @decorator
47 36 def save_ids(f, self, *args, **kwargs):
48 37 """Keep our history and outstanding attributes up to date after a method call."""
49 38 n_previous = len(self.client.history)
50 39 try:
51 40 ret = f(self, *args, **kwargs)
52 41 finally:
53 42 nmsgs = len(self.client.history) - n_previous
54 43 msg_ids = self.client.history[-nmsgs:]
55 44 self.history.extend(msg_ids)
56 45 self.outstanding.update(msg_ids)
57 46 return ret
58 47
59 48 @decorator
60 49 def sync_results(f, self, *args, **kwargs):
61 50 """sync relevant results from self.client to our results attribute."""
62 51 if self._in_sync_results:
63 52 return f(self, *args, **kwargs)
64 53 self._in_sync_results = True
65 54 try:
66 55 ret = f(self, *args, **kwargs)
67 56 finally:
68 57 self._in_sync_results = False
69 58 self._sync_results()
70 59 return ret
71 60
72 61 @decorator
73 62 def spin_after(f, self, *args, **kwargs):
74 63 """call spin after the method."""
75 64 ret = f(self, *args, **kwargs)
76 65 self.spin()
77 66 return ret
78 67
79 68 #-----------------------------------------------------------------------------
80 69 # Classes
81 70 #-----------------------------------------------------------------------------
82 71
83 72 @skip_doctest
84 73 class View(HasTraits):
85 74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
86 75
87 76 Don't use this class, use subclasses.
88 77
89 78 Methods
90 79 -------
91 80
92 81 spin
93 82 flushes incoming results and registration state changes
94 83 control methods spin, and requesting `ids` also ensures up to date
95 84
96 85 wait
97 86 wait on one or more msg_ids
98 87
99 88 execution methods
100 89 apply
101 90 legacy: execute, run
102 91
103 92 data movement
104 93 push, pull, scatter, gather
105 94
106 95 query methods
107 96 get_result, queue_status, purge_results, result_status
108 97
109 98 control methods
110 99 abort, shutdown
111 100
112 101 """
113 102 # flags
114 103 block=Bool(False)
115 104 track=Bool(True)
116 105 targets = Any()
117 106
118 107 history=List()
119 108 outstanding = Set()
120 109 results = Dict()
121 110 client = Instance('IPython.parallel.Client')
122 111
123 112 _socket = Instance('zmq.Socket')
124 113 _flag_names = List(['targets', 'block', 'track'])
125 114 _in_sync_results = Bool(False)
126 115 _targets = Any()
127 116 _idents = Any()
128 117
129 118 def __init__(self, client=None, socket=None, **flags):
130 119 super(View, self).__init__(client=client, _socket=socket)
131 120 self.results = client.results
132 121 self.block = client.block
133 122
134 123 self.set_flags(**flags)
135 124
136 125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
137 126
138 127 def __repr__(self):
139 128 strtargets = str(self.targets)
140 129 if len(strtargets) > 16:
141 130 strtargets = strtargets[:12]+'...]'
142 131 return "<%s %s>"%(self.__class__.__name__, strtargets)
143 132
144 133 def __len__(self):
145 134 if isinstance(self.targets, list):
146 135 return len(self.targets)
147 136 elif isinstance(self.targets, int):
148 137 return 1
149 138 else:
150 139 return len(self.client)
151 140
152 141 def set_flags(self, **kwargs):
153 142 """set my attribute flags by keyword.
154 143
155 144 Views determine behavior with a few attributes (`block`, `track`, etc.).
156 145 These attributes can be set all at once by name with this method.
157 146
158 147 Parameters
159 148 ----------
160 149
161 150 block : bool
162 151 whether to wait for results
163 152 track : bool
164 153 whether to create a MessageTracker to allow the user to
165 154 safely edit arrays and buffers after non-copying
166 155 sends.
167 156 """
168 157 for name, value in iteritems(kwargs):
169 158 if name not in self._flag_names:
170 159 raise KeyError("Invalid name: %r"%name)
171 160 else:
172 161 setattr(self, name, value)
173 162
174 163 @contextmanager
175 164 def temp_flags(self, **kwargs):
176 165 """temporarily set flags, for use in `with` statements.
177 166
178 167 See set_flags for permanent setting of flags
179 168
180 169 Examples
181 170 --------
182 171
183 172 >>> view.track=False
184 173 ...
185 174 >>> with view.temp_flags(track=True):
186 175 ... ar = view.apply(dostuff, my_big_array)
187 176 ... ar.tracker.wait() # wait for send to finish
188 177 >>> view.track
189 178 False
190 179
191 180 """
192 181 # preflight: save flags, and set temporaries
193 182 saved_flags = {}
194 183 for f in self._flag_names:
195 184 saved_flags[f] = getattr(self, f)
196 185 self.set_flags(**kwargs)
197 186 # yield to the with-statement block
198 187 try:
199 188 yield
200 189 finally:
201 190 # postflight: restore saved flags
202 191 self.set_flags(**saved_flags)
203 192
204 193
205 194 #----------------------------------------------------------------
206 195 # apply
207 196 #----------------------------------------------------------------
208 197
209 198 def _sync_results(self):
210 199 """to be called by @sync_results decorator
211 200
212 201 after submitting any tasks.
213 202 """
214 203 delta = self.outstanding.difference(self.client.outstanding)
215 204 completed = self.outstanding.intersection(delta)
216 205 self.outstanding = self.outstanding.difference(completed)
217 206
218 207 @sync_results
219 208 @save_ids
220 209 def _really_apply(self, f, args, kwargs, block=None, **options):
221 210 """wrapper for client.send_apply_request"""
222 211 raise NotImplementedError("Implement in subclasses")
223 212
224 213 def apply(self, f, *args, **kwargs):
225 214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
226 215
227 216 This method sets all apply flags via this View's attributes.
228 217
229 218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
230 219 instance if ``self.block`` is False, otherwise the return value of
231 220 ``f(*args, **kwargs)``.
232 221 """
233 222 return self._really_apply(f, args, kwargs)
234 223
235 224 def apply_async(self, f, *args, **kwargs):
236 225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
237 226
238 227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
239 228 """
240 229 return self._really_apply(f, args, kwargs, block=False)
241 230
242 231 @spin_after
243 232 def apply_sync(self, f, *args, **kwargs):
244 233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
245 234 returning the result.
246 235 """
247 236 return self._really_apply(f, args, kwargs, block=True)
248 237
249 238 #----------------------------------------------------------------
250 239 # wrappers for client and control methods
251 240 #----------------------------------------------------------------
252 241 @sync_results
253 242 def spin(self):
254 243 """spin the client, and sync"""
255 244 self.client.spin()
256 245
257 246 @sync_results
258 247 def wait(self, jobs=None, timeout=-1):
259 248 """waits on one or more `jobs`, for up to `timeout` seconds.
260 249
261 250 Parameters
262 251 ----------
263 252
264 253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
265 254 ints are indices to self.history
266 255 strs are msg_ids
267 256 default: wait on all outstanding messages
268 257 timeout : float
269 258 a time in seconds, after which to give up.
270 259 default is -1, which means no timeout
271 260
272 261 Returns
273 262 -------
274 263
275 264 True : when all msg_ids are done
276 265 False : timeout reached, some msg_ids still outstanding
277 266 """
278 267 if jobs is None:
279 268 jobs = self.history
280 269 return self.client.wait(jobs, timeout)
281 270
282 271 def abort(self, jobs=None, targets=None, block=None):
283 272 """Abort jobs on my engines.
284 273
285 274 Parameters
286 275 ----------
287 276
288 277 jobs : None, str, list of strs, optional
289 278 if None: abort all jobs.
290 279 else: abort specific msg_id(s).
291 280 """
292 281 block = block if block is not None else self.block
293 282 targets = targets if targets is not None else self.targets
294 283 jobs = jobs if jobs is not None else list(self.outstanding)
295 284
296 285 return self.client.abort(jobs=jobs, targets=targets, block=block)
297 286
298 287 def queue_status(self, targets=None, verbose=False):
299 288 """Fetch the Queue status of my engines"""
300 289 targets = targets if targets is not None else self.targets
301 290 return self.client.queue_status(targets=targets, verbose=verbose)
302 291
303 292 def purge_results(self, jobs=[], targets=[]):
304 293 """Instruct the controller to forget specific results."""
305 294 if targets is None or targets == 'all':
306 295 targets = self.targets
307 296 return self.client.purge_results(jobs=jobs, targets=targets)
308 297
309 298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
310 299 """Terminates one or more engine processes, optionally including the hub.
311 300 """
312 301 block = self.block if block is None else block
313 302 if targets is None or targets == 'all':
314 303 targets = self.targets
315 304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
316 305
317 306 @spin_after
318 def get_result(self, indices_or_msg_ids=None):
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
319 308 """return one or more results, specified by history index or msg_id.
320 309
321 310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
322 311 """
323 312
324 313 if indices_or_msg_ids is None:
325 314 indices_or_msg_ids = -1
326 315 if isinstance(indices_or_msg_ids, int):
327 316 indices_or_msg_ids = self.history[indices_or_msg_ids]
328 317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
329 318 indices_or_msg_ids = list(indices_or_msg_ids)
330 319 for i,index in enumerate(indices_or_msg_ids):
331 320 if isinstance(index, int):
332 321 indices_or_msg_ids[i] = self.history[index]
333 return self.client.get_result(indices_or_msg_ids)
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
334 323
335 324 #-------------------------------------------------------------------
336 325 # Map
337 326 #-------------------------------------------------------------------
338 327
339 328 @sync_results
340 329 def map(self, f, *sequences, **kwargs):
341 330 """override in subclasses"""
342 331 raise NotImplementedError
343 332
344 333 def map_async(self, f, *sequences, **kwargs):
345 334 """Parallel version of builtin :func:`python:map`, using this view's engines.
346 335
347 336 This is equivalent to ``map(...block=False)``.
348 337
349 338 See `self.map` for details.
350 339 """
351 340 if 'block' in kwargs:
352 341 raise TypeError("map_async doesn't take a `block` keyword argument.")
353 342 kwargs['block'] = False
354 343 return self.map(f,*sequences,**kwargs)
355 344
356 345 def map_sync(self, f, *sequences, **kwargs):
357 346 """Parallel version of builtin :func:`python:map`, using this view's engines.
358 347
359 348 This is equivalent to ``map(...block=True)``.
360 349
361 350 See `self.map` for details.
362 351 """
363 352 if 'block' in kwargs:
364 353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
365 354 kwargs['block'] = True
366 355 return self.map(f,*sequences,**kwargs)
367 356
368 357 def imap(self, f, *sequences, **kwargs):
369 358 """Parallel version of :func:`itertools.imap`.
370 359
371 360 See `self.map` for details.
372 361
373 362 """
374 363
375 364 return iter(self.map_async(f,*sequences, **kwargs))
376 365
377 366 #-------------------------------------------------------------------
378 367 # Decorators
379 368 #-------------------------------------------------------------------
380 369
381 370 def remote(self, block=None, **flags):
382 371 """Decorator for making a RemoteFunction"""
383 372 block = self.block if block is None else block
384 373 return remote(self, block=block, **flags)
385 374
386 375 def parallel(self, dist='b', block=None, **flags):
387 376 """Decorator for making a ParallelFunction"""
388 377 block = self.block if block is None else block
389 378 return parallel(self, dist=dist, block=block, **flags)
390 379
391 380 @skip_doctest
392 381 class DirectView(View):
393 382 """Direct Multiplexer View of one or more engines.
394 383
395 384 These are created via indexed access to a client:
396 385
397 386 >>> dv_1 = client[1]
398 387 >>> dv_all = client[:]
399 388 >>> dv_even = client[::2]
400 389 >>> dv_some = client[1:3]
401 390
402 391 This object provides dictionary access to engine namespaces:
403 392
404 393 # push a=5:
405 394 >>> dv['a'] = 5
406 395 # pull 'foo':
407 396 >>> dv['foo']
408 397
409 398 """
410 399
411 400 def __init__(self, client=None, socket=None, targets=None):
412 401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
413 402
414 403 @property
415 404 def importer(self):
416 405 """sync_imports(local=True) as a property.
417 406
418 407 See sync_imports for details.
419 408
420 409 """
421 410 return self.sync_imports(True)
422 411
423 412 @contextmanager
424 413 def sync_imports(self, local=True, quiet=False):
425 414 """Context Manager for performing simultaneous local and remote imports.
426 415
427 416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
428 417
429 418 If `local=True`, then the package will also be imported locally.
430 419
431 420 If `quiet=True`, no output will be produced when attempting remote
432 421 imports.
433 422
434 423 Note that remote-only (`local=False`) imports have not been implemented.
435 424
436 425 >>> with view.sync_imports():
437 426 ... from numpy import recarray
438 427 importing recarray from numpy on engine(s)
439 428
440 429 """
441 430 from IPython.utils.py3compat import builtin_mod
442 431 local_import = builtin_mod.__import__
443 432 modules = set()
444 433 results = []
445 434 @util.interactive
446 435 def remote_import(name, fromlist, level):
447 436 """the function to be passed to apply, that actually performs the import
448 437 on the engine, and loads up the user namespace.
449 438 """
450 439 import sys
451 440 user_ns = globals()
452 441 mod = __import__(name, fromlist=fromlist, level=level)
453 442 if fromlist:
454 443 for key in fromlist:
455 444 user_ns[key] = getattr(mod, key)
456 445 else:
457 446 user_ns[name] = sys.modules[name]
458 447
459 448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
460 449 """the drop-in replacement for __import__, that optionally imports
461 450 locally as well.
462 451 """
463 452 # don't override nested imports
464 453 save_import = builtin_mod.__import__
465 454 builtin_mod.__import__ = local_import
466 455
467 456 if imp.lock_held():
468 457 # this is a side-effect import, don't do it remotely, or even
469 458 # ignore the local effects
470 459 return local_import(name, globals, locals, fromlist, level)
471 460
472 461 imp.acquire_lock()
473 462 if local:
474 463 mod = local_import(name, globals, locals, fromlist, level)
475 464 else:
476 465 raise NotImplementedError("remote-only imports not yet implemented")
477 466 imp.release_lock()
478 467
479 468 key = name+':'+','.join(fromlist or [])
480 469 if level <= 0 and key not in modules:
481 470 modules.add(key)
482 471 if not quiet:
483 472 if fromlist:
484 473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
485 474 else:
486 475 print("importing %s on engine(s)"%name)
487 476 results.append(self.apply_async(remote_import, name, fromlist, level))
488 477 # restore override
489 478 builtin_mod.__import__ = save_import
490 479
491 480 return mod
492 481
493 482 # override __import__
494 483 builtin_mod.__import__ = view_import
495 484 try:
496 485 # enter the block
497 486 yield
498 487 except ImportError:
499 488 if local:
500 489 raise
501 490 else:
502 491 # ignore import errors if not doing local imports
503 492 pass
504 493 finally:
505 494 # always restore __import__
506 495 builtin_mod.__import__ = local_import
507 496
508 497 for r in results:
509 498 # raise possible remote ImportErrors here
510 499 r.get()
511 500
512 501 def use_dill(self):
513 502 """Expand serialization support with dill
514 503
515 504 adds support for closures, etc.
516 505
517 506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
518 507 """
519 508 pickleutil.use_dill()
520 509 return self.apply(pickleutil.use_dill)
521 510
522 511 def use_cloudpickle(self):
523 512 """Expand serialization support with cloudpickle.
524 513 """
525 514 pickleutil.use_cloudpickle()
526 515 return self.apply(pickleutil.use_cloudpickle)
527 516
528 517
529 518 @sync_results
530 519 @save_ids
531 520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
532 521 """calls f(*args, **kwargs) on remote engines, returning the result.
533 522
534 523 This method sets all of `apply`'s flags via this View's attributes.
535 524
536 525 Parameters
537 526 ----------
538 527
539 528 f : callable
540 529
541 530 args : list [default: empty]
542 531
543 532 kwargs : dict [default: empty]
544 533
545 534 targets : target list [default: self.targets]
546 535 where to run
547 536 block : bool [default: self.block]
548 537 whether to block
549 538 track : bool [default: self.track]
550 539 whether to ask zmq to track the message, for safe non-copying sends
551 540
552 541 Returns
553 542 -------
554 543
555 544 if self.block is False:
556 545 returns AsyncResult
557 546 else:
558 547 returns actual result of f(*args, **kwargs) on the engine(s)
559 548 This will be a list if self.targets is also a list (even length 1), or
560 549 the single result if self.targets is an integer engine id
561 550 """
562 551 args = [] if args is None else args
563 552 kwargs = {} if kwargs is None else kwargs
564 553 block = self.block if block is None else block
565 554 track = self.track if track is None else track
566 555 targets = self.targets if targets is None else targets
567 556
568 557 _idents, _targets = self.client._build_targets(targets)
569 558 msg_ids = []
570 559 trackers = []
571 560 for ident in _idents:
572 561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
573 562 ident=ident)
574 563 if track:
575 564 trackers.append(msg['tracker'])
576 565 msg_ids.append(msg['header']['msg_id'])
577 566 if isinstance(targets, int):
578 567 msg_ids = msg_ids[0]
579 568 tracker = None if track is False else zmq.MessageTracker(*trackers)
580 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
571 )
581 572 if block:
582 573 try:
583 574 return ar.get()
584 575 except KeyboardInterrupt:
585 576 pass
586 577 return ar
587 578
588 579
589 580 @sync_results
590 581 def map(self, f, *sequences, **kwargs):
591 582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
592 583
593 584 Parallel version of builtin `map`, using this View's `targets`.
594 585
595 586 There will be one task per target, so work will be chunked
596 587 if the sequences are longer than `targets`.
597 588
598 589 Results can be iterated as they are ready, but will become available in chunks.
599 590
600 591 Parameters
601 592 ----------
602 593
603 594 f : callable
604 595 function to be mapped
605 596         *sequences : one or more sequences of matching length
606 597 the sequences to be distributed and passed to `f`
607 598 block : bool
608 599 whether to wait for the result or not [default self.block]
609 600
610 601 Returns
611 602 -------
612 603
613 604
614 605 If block=False
615 606 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
616 607 An object like AsyncResult, but which reassembles the sequence of results
617 608 into a single list. AsyncMapResults can be iterated through before all
618 609 results are complete.
619 610 else
620 611 A list, the result of ``map(f,*sequences)``
621 612 """
622 613
623 614 block = kwargs.pop('block', self.block)
624 615 for k in kwargs.keys():
625 616 if k not in ['block', 'track']:
626 617 raise TypeError("invalid keyword arg, %r"%k)
627 618
628 619 assert len(sequences) > 0, "must have some sequences to map onto!"
629 620 pf = ParallelFunction(self, f, block=block, **kwargs)
630 621 return pf.map(*sequences)
631 622
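A usage sketch of the direct `map` above (assuming a connected Client `rc`):

    dv = rc[:]
    amr = dv.map(lambda x: x * x, range(8), block=False)   # one task per engine
    for sq in amr:         # individual results, available chunk by chunk
        print(sq)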
632 623 @sync_results
633 624 @save_ids
634 625 def execute(self, code, silent=True, targets=None, block=None):
635 626         """Executes `code` on `targets` in a blocking or nonblocking manner.
636 627
637 628 ``execute`` is always `bound` (affects engine namespace)
638 629
639 630 Parameters
640 631 ----------
641 632
642 633 code : str
643 634 the code string to be executed
644 635 block : bool
645 636 whether or not to wait until done to return
646 637 default: self.block
647 638 """
648 639 block = self.block if block is None else block
649 640 targets = self.targets if targets is None else targets
650 641
651 642 _idents, _targets = self.client._build_targets(targets)
652 643 msg_ids = []
653 644 trackers = []
654 645 for ident in _idents:
655 646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
656 647 msg_ids.append(msg['header']['msg_id'])
657 648 if isinstance(targets, int):
658 649 msg_ids = msg_ids[0]
659 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
660 651 if block:
661 652 try:
662 653 ar.get()
663 654 except KeyboardInterrupt:
664 655 pass
665 656 return ar
666 657
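A short sketch of `execute` with the new ownership default (assuming a connected Client `rc`):

    dv = rc[:]
    ar = dv.execute("a = 2 ** 10", block=False)   # returns an owning AsyncResult
    ar.get()                                      # wait for all engines
    print(dv.pull('a', block=True))               # [1024, ...] one per engine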
667 658 def run(self, filename, targets=None, block=None):
668 659 """Execute contents of `filename` on my engine(s).
669 660
670 661 This simply reads the contents of the file and calls `execute`.
671 662
672 663 Parameters
673 664 ----------
674 665
675 666 filename : str
676 667 The path to the file
677 668 targets : int/str/list of ints/strs
678 669 the engines on which to execute
679 670 default : all
680 671 block : bool
681 672 whether or not to wait until done
682 673 default: self.block
683 674
684 675 """
685 676 with open(filename, 'r') as f:
686 677 # add newline in case of trailing indented whitespace
687 678 # which will cause SyntaxError
688 679 code = f.read()+'\n'
689 680 return self.execute(code, block=block, targets=targets)
690 681
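And `run` is just `execute` on a file's contents, e.g. (the script name is hypothetical):

    dv.run('setup_engines.py', block=True)   # reads the file, appends '\n', executes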
691 682 def update(self, ns):
692 683 """update remote namespace with dict `ns`
693 684
694 685 See `push` for details.
695 686 """
696 687 return self.push(ns, block=self.block, track=self.track)
697 688
698 689 def push(self, ns, targets=None, block=None, track=None):
699 690 """update remote namespace with dict `ns`
700 691
701 692 Parameters
702 693 ----------
703 694
704 695 ns : dict
705 696 dict of keys with which to update engine namespace(s)
706 697 block : bool [default : self.block]
707 698 whether to wait to be notified of engine receipt
708 699
709 700 """
710 701
711 702 block = block if block is not None else self.block
712 703 track = track if track is not None else self.track
713 704 targets = targets if targets is not None else self.targets
714 705 # applier = self.apply_sync if block else self.apply_async
715 706 if not isinstance(ns, dict):
716 707 raise TypeError("Must be a dict, not %s"%type(ns))
717 708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
718 709
719 710 def get(self, key_s):
720 711 """get object(s) by `key_s` from remote namespace
721 712
722 713 see `pull` for details.
723 714 """
724 715 # block = block if block is not None else self.block
725 716 return self.pull(key_s, block=True)
726 717
727 718 def pull(self, names, targets=None, block=None):
728 719         """get object(s) by `names` from remote namespace
729 720 
730 721         will return one object if `names` is a single key.
731 722         can also take a list of keys, in which case it will return a list of objects.
732 723 """
733 724 block = block if block is not None else self.block
734 725 targets = targets if targets is not None else self.targets
735 726 applier = self.apply_sync if block else self.apply_async
736 727 if isinstance(names, string_types):
737 728 pass
738 729 elif isinstance(names, (list,tuple,set)):
739 730 for key in names:
740 731 if not isinstance(key, string_types):
741 732 raise TypeError("keys must be str, not type %r"%type(key))
742 733 else:
743 734 raise TypeError("names must be strs, not %r"%names)
744 735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
745 736
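A round-trip sketch for `push`/`pull` (assuming a connected Client `rc`), matching what the existing push/pull tests exercise:

    dv = rc[:]
    dv.push(dict(a=5, b='hi'), block=True)    # update each engine's namespace
    print(dv.pull('a', block=True))           # [5, ...] one per engine
    print(dv.pull(('a', 'b'), block=True))    # [[5, 'hi'], ...]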
746 737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
747 738 """
748 739 Partition a Python sequence and send the partitions to a set of engines.
749 740 """
750 741 block = block if block is not None else self.block
751 742 track = track if track is not None else self.track
752 743 targets = targets if targets is not None else self.targets
753 744
754 745 # construct integer ID list:
755 746 targets = self.client._build_targets(targets)[1]
756 747
757 748 mapObject = Map.dists[dist]()
758 749 nparts = len(targets)
759 750 msg_ids = []
760 751 trackers = []
761 752 for index, engineid in enumerate(targets):
762 753 partition = mapObject.getPartition(seq, index, nparts)
763 754 if flatten and len(partition) == 1:
764 755 ns = {key: partition[0]}
765 756 else:
766 757 ns = {key: partition}
767 758 r = self.push(ns, block=False, track=track, targets=engineid)
768 759 msg_ids.extend(r.msg_ids)
769 760 if track:
770 761 trackers.append(r._tracker)
771 762
772 763 if track:
773 764 tracker = zmq.MessageTracker(*trackers)
774 765 else:
775 766 tracker = None
776 767
777 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
770 )
778 771 if block:
779 772 r.wait()
780 773 else:
781 774 return r
782 775
783 776 @sync_results
784 777 @save_ids
785 778 def gather(self, key, dist='b', targets=None, block=None):
786 779 """
787 780 Gather a partitioned sequence on a set of engines as a single local seq.
788 781 """
789 782 block = block if block is not None else self.block
790 783 targets = targets if targets is not None else self.targets
791 784 mapObject = Map.dists[dist]()
792 785 msg_ids = []
793 786
794 787 # construct integer ID list:
795 788 targets = self.client._build_targets(targets)[1]
796 789
797 790 for index, engineid in enumerate(targets):
798 791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
799 792
800 793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
801 794
802 795 if block:
803 796 try:
804 797 return r.get()
805 798 except KeyboardInterrupt:
806 799 pass
807 800 return r
808 801
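Scatter and gather compose into the usual partition/compute/reassemble pattern, sketched here for a hypothetical connected Client `rc`:

    dv = rc[:]
    dv.scatter('chunk', list(range(16)), block=True)         # partition across engines
    dv.execute('part = [x * 2 for x in chunk]', block=True)  # work on each partition
    print(dv.gather('part', block=True))                     # reassembled single list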
809 802 def __getitem__(self, key):
810 803 return self.get(key)
811 804
812 805 def __setitem__(self,key, value):
813 806 self.update({key:value})
814 807
815 808 def clear(self, targets=None, block=None):
816 809 """Clear the remote namespaces on my engines."""
817 810 block = block if block is not None else self.block
818 811 targets = targets if targets is not None else self.targets
819 812 return self.client.clear(targets=targets, block=block)
820 813
821 814 #----------------------------------------
822 815 # activate for %px, %autopx, etc. magics
823 816 #----------------------------------------
824 817
825 818 def activate(self, suffix=''):
826 819 """Activate IPython magics associated with this View
827 820
828 821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
829 822
830 823 Parameters
831 824 ----------
832 825
833 826         suffix : str [default: '']
834 827 The suffix, if any, for the magics. This allows you to have
835 828 multiple views associated with parallel magics at the same time.
836 829
837 830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
838 831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
839 832 on the even engines.
840 833 """
841 834
842 835 from IPython.parallel.client.magics import ParallelMagics
843 836
844 837 try:
845 838 # This is injected into __builtins__.
846 839 ip = get_ipython()
847 840 except NameError:
848 841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
849 842 return
850 843
851 844 M = ParallelMagics(ip, self, suffix)
852 845 ip.magics_manager.register(M)
853 846
854 847
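For example, inside an IPython session with a connected Client `rc`:

    rc[:].activate()              # registers %px, %%px, %autopx, %pxresult, %pxconfig
    rc[::2].activate('_even')     # registers %px_even etc. for the even engines

    %px import os                 # line magic: runs on all engines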
855 848 @skip_doctest
856 849 class LoadBalancedView(View):
857 850     """A load-balancing View that only executes via the Task scheduler.
858 851
859 852     Load-balanced views can be created with the client's `load_balanced_view` method:
860 853
861 854 >>> v = client.load_balanced_view()
862 855
863 856 or targets can be specified, to restrict the potential destinations:
864 857
865 858     >>> v = client.load_balanced_view([1,3])
866 859
867 860     which would restrict load-balancing to engines 1 and 3.
868 861
869 862 """
870 863
871 864 follow=Any()
872 865 after=Any()
873 866 timeout=CFloat()
874 867 retries = Integer(0)
875 868
876 869 _task_scheme = Any()
877 870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
878 871
879 872 def __init__(self, client=None, socket=None, **flags):
880 873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
881 874 self._task_scheme=client._task_scheme
882 875
883 876 def _validate_dependency(self, dep):
884 877 """validate a dependency.
885 878
886 879 For use in `set_flags`.
887 880 """
888 881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
889 882 return True
890 883 elif isinstance(dep, (list,set, tuple)):
891 884 for d in dep:
892 885 if not isinstance(d, string_types + (AsyncResult,)):
893 886 return False
894 887 elif isinstance(dep, dict):
895 888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
896 889 return False
897 890 if not isinstance(dep['msg_ids'], list):
898 891 return False
899 892 for d in dep['msg_ids']:
900 893 if not isinstance(d, string_types):
901 894 return False
902 895 else:
903 896 return False
904 897
905 898 return True
906 899
907 900 def _render_dependency(self, dep):
908 901 """helper for building jsonable dependencies from various input forms."""
909 902 if isinstance(dep, Dependency):
910 903 return dep.as_dict()
911 904 elif isinstance(dep, AsyncResult):
912 905 return dep.msg_ids
913 906 elif dep is None:
914 907 return []
915 908 else:
916 909 # pass to Dependency constructor
917 910 return list(Dependency(dep))
918 911
919 912 def set_flags(self, **kwargs):
920 913 """set my attribute flags by keyword.
921 914
922 915         A View is a wrapper for the Client's apply method, but with attributes
923 916         that specify keyword arguments; those attributes can be set by keyword
924 917         argument with this method.
925 918
926 919 Parameters
927 920 ----------
928 921
929 922 block : bool
930 923 whether to wait for results
931 924 track : bool
932 925 whether to create a MessageTracker to allow the user to
933 926 safely edit after arrays and buffers during non-copying
934 927 sends.
935 928
936 929 after : Dependency or collection of msg_ids
937 930 Only for load-balanced execution (targets=None)
938 931 Specify a list of msg_ids as a time-based dependency.
939 932 This job will only be run *after* the dependencies
940 933 have been met.
941 934
942 935 follow : Dependency or collection of msg_ids
943 936 Only for load-balanced execution (targets=None)
944 937 Specify a list of msg_ids as a location-based dependency.
945 938 This job will only be run on an engine where this dependency
946 939 is met.
947 940
948 941 timeout : float/int or None
949 942 Only for load-balanced execution (targets=None)
950 943 Specify an amount of time (in seconds) for the scheduler to
951 944 wait for dependencies to be met before failing with a
952 945 DependencyTimeout.
953 946
954 947 retries : int
955 948 Number of times a task will be retried on failure.
956 949 """
957 950
958 951 super(LoadBalancedView, self).set_flags(**kwargs)
959 952 for name in ('follow', 'after'):
960 953 if name in kwargs:
961 954 value = kwargs[name]
962 955 if self._validate_dependency(value):
963 956 setattr(self, name, value)
964 957 else:
965 958 raise ValueError("Invalid dependency: %r"%value)
966 959 if 'timeout' in kwargs:
967 960 t = kwargs['timeout']
968 961 if not isinstance(t, (int, float, type(None))):
969 962 if (not PY3) and (not isinstance(t, long)):
970 963 raise TypeError("Invalid type for timeout: %r"%type(t))
971 964 if t is not None:
972 965 if t < 0:
973 966 raise ValueError("Invalid timeout: %s"%t)
974 967 self.timeout = t
975 968
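A sketch of the scheduler flags in use (`rc`, `load_data`, and `process_data` are hypothetical):

    v = rc.load_balanced_view()
    setup = v.apply_async(load_data)                 # some long-running setup task
    v.set_flags(after=setup, timeout=30, retries=2)  # dependency + failure policy
    ar = v.apply_async(process_data)                 # runs only after setup has completed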
976 969 @sync_results
977 970 @save_ids
978 971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
979 972 after=None, follow=None, timeout=None,
980 973 targets=None, retries=None):
981 974 """calls f(*args, **kwargs) on a remote engine, returning the result.
982 975
983 976 This method temporarily sets all of `apply`'s flags for a single call.
984 977
985 978 Parameters
986 979 ----------
987 980
988 981 f : callable
989 982
990 983 args : list [default: empty]
991 984
992 985 kwargs : dict [default: empty]
993 986
994 987 block : bool [default: self.block]
995 988 whether to block
996 989 track : bool [default: self.track]
997 990 whether to ask zmq to track the message, for safe non-copying sends
998 991
999 992         after, follow, timeout, retries : scheduler flags, see `set_flags`
1000 993
1001 994 Returns
1002 995 -------
1003 996
1004 997 if self.block is False:
1005 998 returns AsyncResult
1006 999 else:
1007 1000 returns actual result of f(*args, **kwargs) on the engine(s)
1008 1001         This will be a list if self.targets is also a list (even length 1), or
1009 1002 the single result if self.targets is an integer engine id
1010 1003 """
1011 1004
1012 1005 # validate whether we can run
1013 1006 if self._socket.closed:
1014 1007 msg = "Task farming is disabled"
1015 1008 if self._task_scheme == 'pure':
1016 1009 msg += " because the pure ZMQ scheduler cannot handle"
1017 1010 msg += " disappearing engines."
1018 1011 raise RuntimeError(msg)
1019 1012
1020 1013 if self._task_scheme == 'pure':
1021 1014 # pure zmq scheme doesn't support extra features
1022 1015             msg = ("Pure ZMQ scheduler doesn't support the following flags: "
1023 1016                 "follow, after, retries, targets, timeout")
1024 1017 if (follow or after or retries or targets or timeout):
1025 1018 # hard fail on Scheduler flags
1026 1019 raise RuntimeError(msg)
1027 1020 if isinstance(f, dependent):
1028 1021 # soft warn on functional dependencies
1029 1022 warnings.warn(msg, RuntimeWarning)
1030 1023
1031 1024 # build args
1032 1025 args = [] if args is None else args
1033 1026 kwargs = {} if kwargs is None else kwargs
1034 1027 block = self.block if block is None else block
1035 1028 track = self.track if track is None else track
1036 1029 after = self.after if after is None else after
1037 1030 retries = self.retries if retries is None else retries
1038 1031 follow = self.follow if follow is None else follow
1039 1032 timeout = self.timeout if timeout is None else timeout
1040 1033 targets = self.targets if targets is None else targets
1041 1034
1042 1035 if not isinstance(retries, int):
1043 1036 raise TypeError('retries must be int, not %r'%type(retries))
1044 1037
1045 1038 if targets is None:
1046 1039 idents = []
1047 1040 else:
1048 1041 idents = self.client._build_targets(targets)[0]
1049 1042 # ensure *not* bytes
1050 1043 idents = [ ident.decode() for ident in idents ]
1051 1044
1052 1045 after = self._render_dependency(after)
1053 1046 follow = self._render_dependency(follow)
1054 1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1055 1048
1056 1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1057 1050 metadata=metadata)
1058 1051 tracker = None if track is False else msg['tracker']
1059 1052
1060 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1061
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 targets=None, tracker=tracker, owner=True,
1055 )
1062 1056 if block:
1063 1057 try:
1064 1058 return ar.get()
1065 1059 except KeyboardInterrupt:
1066 1060 pass
1067 1061 return ar
1068 1062
1069 1063 @sync_results
1070 1064 @save_ids
1071 1065 def map(self, f, *sequences, **kwargs):
1072 1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1073 1067
1074 1068 Parallel version of builtin `map`, load-balanced by this View.
1075 1069
1076 1070     `block`, `chunksize`, and `ordered` can be specified by keyword only.
1077 1071
1078 1072 Each `chunksize` elements will be a separate task, and will be
1079 1073 load-balanced. This lets individual elements be available for iteration
1080 1074 as soon as they arrive.
1081 1075
1082 1076 Parameters
1083 1077 ----------
1084 1078
1085 1079 f : callable
1086 1080 function to be mapped
1087 1081         *sequences : one or more sequences of matching length
1088 1082 the sequences to be distributed and passed to `f`
1089 1083 block : bool [default self.block]
1090 1084 whether to wait for the result or not
1091 1085 track : bool
1092 1086 whether to create a MessageTracker to allow the user to
1093 1087 safely edit after arrays and buffers during non-copying
1094 1088 sends.
1095 1089 chunksize : int [default 1]
1096 1090 how many elements should be in each task.
1097 1091 ordered : bool [default True]
1098 1092 Whether the results should be gathered as they arrive, or enforce
1099 1093 the order of submission.
1100 1094
1101 1095 Only applies when iterating through AsyncMapResult as results arrive.
1102 1096 Has no effect when block=True.
1103 1097
1104 1098 Returns
1105 1099 -------
1106 1100
1107 1101 if block=False
1108 1102 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1109 1103 An object like AsyncResult, but which reassembles the sequence of results
1110 1104 into a single list. AsyncMapResults can be iterated through before all
1111 1105 results are complete.
1112 1106 else
1113 1107 A list, the result of ``map(f,*sequences)``
1114 1108 """
1115 1109
1116 1110 # default
1117 1111 block = kwargs.get('block', self.block)
1118 1112 chunksize = kwargs.get('chunksize', 1)
1119 1113 ordered = kwargs.get('ordered', True)
1120 1114
1121 1115 keyset = set(kwargs.keys())
1122 1116         extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1123 1117 if extra_keys:
1124 1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1125 1119
1126 1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1127 1121
1128 1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1129 1123 return pf.map(*sequences)
1130 1124
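A sketch of the load-balanced `map` with its extra knobs (`rc`, `slow_fn`, and `handle` are hypothetical):

    v = rc.load_balanced_view()
    amr = v.map(slow_fn, range(100), block=False, chunksize=10, ordered=False)
    for r in amr:      # iterate results as tasks finish, not in submission order
        handle(r)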
1131 1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,327 +1,342 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py"""
2 2
3 Authors:
4
5 * Min RK
6 """
7
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
14
15 #-------------------------------------------------------------------------------
16 # Imports
17 #-------------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
18 5
19 6 import time
20 7
21 8 import nose.tools as nt
22 9
23 10 from IPython.utils.io import capture_output
24 11
25 12 from IPython.parallel.error import TimeoutError
26 13 from IPython.parallel import error, Client
27 14 from IPython.parallel.tests import add_engines
28 15 from .clienttest import ClusterTestCase
29 16 from IPython.utils.py3compat import iteritems
30 17
31 18 def setup():
32 19 add_engines(2, total=True)
33 20
34 21 def wait(n):
35 22 import time
36 23 time.sleep(n)
37 24 return n
38 25
39 26 def echo(x):
40 27 return x
41 28
42 29 class AsyncResultTest(ClusterTestCase):
43 30
44 31 def test_single_result_view(self):
45 32 """various one-target views get the right value for single_result"""
46 33 eid = self.client.ids[-1]
47 34 ar = self.client[eid].apply_async(lambda : 42)
48 35 self.assertEqual(ar.get(), 42)
49 36 ar = self.client[[eid]].apply_async(lambda : 42)
50 37 self.assertEqual(ar.get(), [42])
51 38 ar = self.client[-1:].apply_async(lambda : 42)
52 39 self.assertEqual(ar.get(), [42])
53 40
54 41 def test_get_after_done(self):
55 42 ar = self.client[-1].apply_async(lambda : 42)
56 43 ar.wait()
57 44 self.assertTrue(ar.ready())
58 45 self.assertEqual(ar.get(), 42)
59 46 self.assertEqual(ar.get(), 42)
60 47
61 48 def test_get_before_done(self):
62 49 ar = self.client[-1].apply_async(wait, 0.1)
63 50 self.assertRaises(TimeoutError, ar.get, 0)
64 51 ar.wait(0)
65 52 self.assertFalse(ar.ready())
66 53 self.assertEqual(ar.get(), 0.1)
67 54
68 55 def test_get_after_error(self):
69 56 ar = self.client[-1].apply_async(lambda : 1/0)
70 57 ar.wait(10)
71 58 self.assertRaisesRemote(ZeroDivisionError, ar.get)
72 59 self.assertRaisesRemote(ZeroDivisionError, ar.get)
73 60 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
74 61
75 62 def test_get_dict(self):
76 63 n = len(self.client)
77 64 ar = self.client[:].apply_async(lambda : 5)
78 65 self.assertEqual(ar.get(), [5]*n)
79 66 d = ar.get_dict()
80 67 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
81 68 for eid,r in iteritems(d):
82 69 self.assertEqual(r, 5)
83 70
84 71 def test_get_dict_single(self):
85 72 view = self.client[-1]
86 73 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
87 74 ar = view.apply_async(echo, v)
88 75 self.assertEqual(ar.get(), v)
89 76 d = ar.get_dict()
90 77 self.assertEqual(d, {view.targets : v})
91 78
92 79 def test_get_dict_bad(self):
93 80 ar = self.client[:].apply_async(lambda : 5)
94 81 ar2 = self.client[:].apply_async(lambda : 5)
95 82 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
96 83 self.assertRaises(ValueError, ar.get_dict)
97 84
98 85 def test_list_amr(self):
99 86 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
100 87 rlist = list(ar)
101 88
102 89 def test_getattr(self):
103 90 ar = self.client[:].apply_async(wait, 0.5)
104 91 self.assertEqual(ar.engine_id, [None] * len(ar))
105 92 self.assertRaises(AttributeError, lambda : ar._foo)
106 93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
107 94 self.assertRaises(AttributeError, lambda : ar.foo)
108 95 self.assertFalse(hasattr(ar, '__length_hint__'))
109 96 self.assertFalse(hasattr(ar, 'foo'))
110 97 self.assertTrue(hasattr(ar, 'engine_id'))
111 98 ar.get(5)
112 99 self.assertRaises(AttributeError, lambda : ar._foo)
113 100 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
114 101 self.assertRaises(AttributeError, lambda : ar.foo)
115 102 self.assertTrue(isinstance(ar.engine_id, list))
116 103 self.assertEqual(ar.engine_id, ar['engine_id'])
117 104 self.assertFalse(hasattr(ar, '__length_hint__'))
118 105 self.assertFalse(hasattr(ar, 'foo'))
119 106 self.assertTrue(hasattr(ar, 'engine_id'))
120 107
121 108 def test_getitem(self):
122 109 ar = self.client[:].apply_async(wait, 0.5)
123 110 self.assertEqual(ar['engine_id'], [None] * len(ar))
124 111 self.assertRaises(KeyError, lambda : ar['foo'])
125 112 ar.get(5)
126 113 self.assertRaises(KeyError, lambda : ar['foo'])
127 114 self.assertTrue(isinstance(ar['engine_id'], list))
128 115 self.assertEqual(ar.engine_id, ar['engine_id'])
129 116
130 117 def test_single_result(self):
131 118 ar = self.client[-1].apply_async(wait, 0.5)
132 119 self.assertRaises(KeyError, lambda : ar['foo'])
133 120 self.assertEqual(ar['engine_id'], None)
134 121 self.assertTrue(ar.get(5) == 0.5)
135 122 self.assertTrue(isinstance(ar['engine_id'], int))
136 123 self.assertTrue(isinstance(ar.engine_id, int))
137 124 self.assertEqual(ar.engine_id, ar['engine_id'])
138 125
139 126 def test_abort(self):
140 127 e = self.client[-1]
141 128 ar = e.execute('import time; time.sleep(1)', block=False)
142 129 ar2 = e.apply_async(lambda : 2)
143 130 ar2.abort()
144 131 self.assertRaises(error.TaskAborted, ar2.get)
145 132 ar.get()
146 133
147 134 def test_len(self):
148 135 v = self.client.load_balanced_view()
149 136 ar = v.map_async(lambda x: x, list(range(10)))
150 137 self.assertEqual(len(ar), 10)
151 138 ar = v.apply_async(lambda x: x, list(range(10)))
152 139 self.assertEqual(len(ar), 1)
153 140 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
154 141 self.assertEqual(len(ar), len(self.client.ids))
155 142
156 143 def test_wall_time_single(self):
157 144 v = self.client.load_balanced_view()
158 145 ar = v.apply_async(time.sleep, 0.25)
159 146 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
160 147 ar.get(2)
161 148 self.assertTrue(ar.wall_time < 1.)
162 149 self.assertTrue(ar.wall_time > 0.2)
163 150
164 151 def test_wall_time_multi(self):
165 152 self.minimum_engines(4)
166 153 v = self.client[:]
167 154 ar = v.apply_async(time.sleep, 0.25)
168 155 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
169 156 ar.get(2)
170 157 self.assertTrue(ar.wall_time < 1.)
171 158 self.assertTrue(ar.wall_time > 0.2)
172 159
173 160 def test_serial_time_single(self):
174 161 v = self.client.load_balanced_view()
175 162 ar = v.apply_async(time.sleep, 0.25)
176 163 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
177 164 ar.get(2)
178 165 self.assertTrue(ar.serial_time < 1.)
179 166 self.assertTrue(ar.serial_time > 0.2)
180 167
181 168 def test_serial_time_multi(self):
182 169 self.minimum_engines(4)
183 170 v = self.client[:]
184 171 ar = v.apply_async(time.sleep, 0.25)
185 172 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
186 173 ar.get(2)
187 174 self.assertTrue(ar.serial_time < 2.)
188 175 self.assertTrue(ar.serial_time > 0.8)
189 176
190 177 def test_elapsed_single(self):
191 178 v = self.client.load_balanced_view()
192 179 ar = v.apply_async(time.sleep, 0.25)
193 180 while not ar.ready():
194 181 time.sleep(0.01)
195 182 self.assertTrue(ar.elapsed < 1)
196 183 self.assertTrue(ar.elapsed < 1)
197 184 ar.get(2)
198 185
199 186 def test_elapsed_multi(self):
200 187 v = self.client[:]
201 188 ar = v.apply_async(time.sleep, 0.25)
202 189 while not ar.ready():
203 190 time.sleep(0.01)
204 191 self.assertTrue(ar.elapsed < 1)
205 192 self.assertTrue(ar.elapsed < 1)
206 193 ar.get(2)
207 194
208 195 def test_hubresult_timestamps(self):
209 196 self.minimum_engines(4)
210 197 v = self.client[:]
211 198 ar = v.apply_async(time.sleep, 0.25)
212 199 ar.get(2)
213 200 rc2 = Client(profile='iptest')
214 201 # must have try/finally to close second Client, otherwise
215 202 # will have dangling sockets causing problems
216 203 try:
217 204 time.sleep(0.25)
218 205 hr = rc2.get_result(ar.msg_ids)
219 206 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
220 207 hr.get(1)
221 208 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
222 209 self.assertEqual(hr.serial_time, ar.serial_time)
223 210 finally:
224 211 rc2.close()
225 212
226 213 def test_display_empty_streams_single(self):
227 214 """empty stdout/err are not displayed (single result)"""
228 215 self.minimum_engines(1)
229 216
230 217 v = self.client[-1]
231 218 ar = v.execute("print (5555)")
232 219 ar.get(5)
233 220 with capture_output() as io:
234 221 ar.display_outputs()
235 222 self.assertEqual(io.stderr, '')
236 223 self.assertEqual('5555\n', io.stdout)
237 224
238 225 ar = v.execute("a=5")
239 226 ar.get(5)
240 227 with capture_output() as io:
241 228 ar.display_outputs()
242 229 self.assertEqual(io.stderr, '')
243 230 self.assertEqual(io.stdout, '')
244 231
245 232 def test_display_empty_streams_type(self):
246 233 """empty stdout/err are not displayed (groupby type)"""
247 234 self.minimum_engines(1)
248 235
249 236 v = self.client[:]
250 237 ar = v.execute("print (5555)")
251 238 ar.get(5)
252 239 with capture_output() as io:
253 240 ar.display_outputs()
254 241 self.assertEqual(io.stderr, '')
255 242 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
256 243 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 244 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
258 245
259 246 ar = v.execute("a=5")
260 247 ar.get(5)
261 248 with capture_output() as io:
262 249 ar.display_outputs()
263 250 self.assertEqual(io.stderr, '')
264 251 self.assertEqual(io.stdout, '')
265 252
266 253 def test_display_empty_streams_engine(self):
267 254 """empty stdout/err are not displayed (groupby engine)"""
268 255 self.minimum_engines(1)
269 256
270 257 v = self.client[:]
271 258 ar = v.execute("print (5555)")
272 259 ar.get(5)
273 260 with capture_output() as io:
274 261 ar.display_outputs('engine')
275 262 self.assertEqual(io.stderr, '')
276 263 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
277 264 self.assertFalse('\n\n' in io.stdout, io.stdout)
278 265 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
279 266
280 267 ar = v.execute("a=5")
281 268 ar.get(5)
282 269 with capture_output() as io:
283 270 ar.display_outputs('engine')
284 271 self.assertEqual(io.stderr, '')
285 272 self.assertEqual(io.stdout, '')
286 273
287 274 def test_await_data(self):
288 275 """asking for ar.data flushes outputs"""
289 276 self.minimum_engines(1)
290 277
291 278 v = self.client[-1]
292 279 ar = v.execute('\n'.join([
293 280 "import time",
294 281 "from IPython.kernel.zmq.datapub import publish_data",
295 282 "for i in range(5):",
296 283 " publish_data(dict(i=i))",
297 284 " time.sleep(0.1)",
298 285 ]), block=False)
299 286 found = set()
300 287 tic = time.time()
301 288 # timeout after 10s
302 289 while time.time() <= tic + 10:
303 290 if ar.data:
304 291 i = ar.data['i']
305 292 found.add(i)
306 293 if i == 4:
307 294 break
308 295 time.sleep(0.05)
309 296
310 297 ar.get(5)
311 298 nt.assert_in(4, found)
312 299 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
313 300
314 301 def test_not_single_result(self):
315 302 save_build = self.client._build_targets
316 303 def single_engine(*a, **kw):
317 304 idents, targets = save_build(*a, **kw)
318 305 return idents[:1], targets[:1]
319 306 ids = single_engine('all')[1]
320 307 self.client._build_targets = single_engine
321 308 for targets in ('all', None, ids):
322 309 dv = self.client.direct_view(targets=targets)
323 310 ar = dv.apply_async(lambda : 5)
324 311 self.assertEqual(ar.get(10), [5])
325 312 self.client._build_targets = save_build
313
314 def test_owner_pop(self):
315 self.minimum_engines(1)
316
317 view = self.client[-1]
318 ar = view.apply_async(lambda : 1)
319 ar.get()
320 msg_id = ar.msg_ids[0]
321 self.assertNotIn(msg_id, self.client.results)
322 self.assertNotIn(msg_id, self.client.metadata)
323
324 def test_non_owner(self):
325 self.minimum_engines(1)
326
327 view = self.client[-1]
328 ar = view.apply_async(lambda : 1)
329 ar.owner = False
330 ar.get()
331 msg_id = ar.msg_ids[0]
332 self.assertIn(msg_id, self.client.results)
333 self.assertIn(msg_id, self.client.metadata)
334 ar2 = self.client.get_result(msg_id, owner=True)
335 self.assertIs(type(ar2), type(ar))
336 self.assertTrue(ar2.owner)
337 self.assertEqual(ar.get(), ar2.get())
338 ar2.get()
339 self.assertNotIn(msg_id, self.client.results)
340 self.assertNotIn(msg_id, self.client.metadata)
326 341
327 342
@@ -1,547 +1,550 b''
1 1 """Tests for parallel client.py"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import division
7 7
8 8 import time
9 9 from datetime import datetime
10 10
11 11 import zmq
12 12
13 13 from IPython import parallel
14 14 from IPython.parallel.client import client as clientmod
15 15 from IPython.parallel import error
16 16 from IPython.parallel import AsyncResult, AsyncHubResult
17 17 from IPython.parallel import LoadBalancedView, DirectView
18 18
19 19 from .clienttest import ClusterTestCase, segfault, wait, add_engines
20 20
21 21 def setup():
22 22 add_engines(4, total=True)
23 23
24 24 class TestClient(ClusterTestCase):
25 25
26 26 def test_ids(self):
27 27 n = len(self.client.ids)
28 28 self.add_engines(2)
29 29 self.assertEqual(len(self.client.ids), n+2)
30 30
31 31 def test_iter(self):
32 32 self.minimum_engines(4)
33 33 engine_ids = [ view.targets for view in self.client ]
34 34 self.assertEqual(engine_ids, self.client.ids)
35 35
36 36 def test_view_indexing(self):
37 37 """test index access for views"""
38 38 self.minimum_engines(4)
39 39 targets = self.client._build_targets('all')[-1]
40 40 v = self.client[:]
41 41 self.assertEqual(v.targets, targets)
42 42 t = self.client.ids[2]
43 43 v = self.client[t]
44 44 self.assertTrue(isinstance(v, DirectView))
45 45 self.assertEqual(v.targets, t)
46 46 t = self.client.ids[2:4]
47 47 v = self.client[t]
48 48 self.assertTrue(isinstance(v, DirectView))
49 49 self.assertEqual(v.targets, t)
50 50 v = self.client[::2]
51 51 self.assertTrue(isinstance(v, DirectView))
52 52 self.assertEqual(v.targets, targets[::2])
53 53 v = self.client[1::3]
54 54 self.assertTrue(isinstance(v, DirectView))
55 55 self.assertEqual(v.targets, targets[1::3])
56 56 v = self.client[:-3]
57 57 self.assertTrue(isinstance(v, DirectView))
58 58 self.assertEqual(v.targets, targets[:-3])
59 59 v = self.client[-1]
60 60 self.assertTrue(isinstance(v, DirectView))
61 61 self.assertEqual(v.targets, targets[-1])
62 62 self.assertRaises(TypeError, lambda : self.client[None])
63 63
64 64 def test_lbview_targets(self):
65 65 """test load_balanced_view targets"""
66 66 v = self.client.load_balanced_view()
67 67 self.assertEqual(v.targets, None)
68 68 v = self.client.load_balanced_view(-1)
69 69 self.assertEqual(v.targets, [self.client.ids[-1]])
70 70 v = self.client.load_balanced_view('all')
71 71 self.assertEqual(v.targets, None)
72 72
73 73 def test_dview_targets(self):
74 74 """test direct_view targets"""
75 75 v = self.client.direct_view()
76 76 self.assertEqual(v.targets, 'all')
77 77 v = self.client.direct_view('all')
78 78 self.assertEqual(v.targets, 'all')
79 79 v = self.client.direct_view(-1)
80 80 self.assertEqual(v.targets, self.client.ids[-1])
81 81
82 82 def test_lazy_all_targets(self):
83 83 """test lazy evaluation of rc.direct_view('all')"""
84 84 v = self.client.direct_view()
85 85 self.assertEqual(v.targets, 'all')
86 86
87 87 def double(x):
88 88 return x*2
89 89 seq = list(range(100))
90 90 ref = [ double(x) for x in seq ]
91 91
92 92 # add some engines, which should be used
93 93 self.add_engines(1)
94 94 n1 = len(self.client.ids)
95 95
96 96 # simple apply
97 97 r = v.apply_sync(lambda : 1)
98 98 self.assertEqual(r, [1] * n1)
99 99
100 100 # map goes through remotefunction
101 101 r = v.map_sync(double, seq)
102 102 self.assertEqual(r, ref)
103 103
104 104 # add a couple more engines, and try again
105 105 self.add_engines(2)
106 106 n2 = len(self.client.ids)
107 107 self.assertNotEqual(n2, n1)
108 108
109 109 # apply
110 110 r = v.apply_sync(lambda : 1)
111 111 self.assertEqual(r, [1] * n2)
112 112
113 113 # map
114 114 r = v.map_sync(double, seq)
115 115 self.assertEqual(r, ref)
116 116
117 117 def test_targets(self):
118 118 """test various valid targets arguments"""
119 119 build = self.client._build_targets
120 120 ids = self.client.ids
121 121 idents,targets = build(None)
122 122 self.assertEqual(ids, targets)
123 123
124 124 def test_clear(self):
125 125 """test clear behavior"""
126 126 self.minimum_engines(2)
127 127 v = self.client[:]
128 128 v.block=True
129 129 v.push(dict(a=5))
130 130 v.pull('a')
131 131 id0 = self.client.ids[-1]
132 132 self.client.clear(targets=id0, block=True)
133 133 a = self.client[:-1].get('a')
134 134 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
135 135 self.client.clear(block=True)
136 136 for i in self.client.ids:
137 137 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
138 138
139 139 def test_get_result(self):
140 140 """test getting results from the Hub."""
141 141 c = clientmod.Client(profile='iptest')
142 142 t = c.ids[-1]
143 143 ar = c[t].apply_async(wait, 1)
144 144 # give the monitor time to notice the message
145 145 time.sleep(.25)
146 ahr = self.client.get_result(ar.msg_ids[0])
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
147 self.assertIsInstance(ahr, AsyncHubResult)
148 148 self.assertEqual(ahr.get(), ar.get())
149 149 ar2 = self.client.get_result(ar.msg_ids[0])
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 self.assertNotIsInstance(ar2, AsyncHubResult)
151 self.assertEqual(ahr.get(), ar2.get())
151 152 c.close()
152 153
153 154 def test_get_execute_result(self):
154 155 """test getting execute results from the Hub."""
155 156 c = clientmod.Client(profile='iptest')
156 157 t = c.ids[-1]
157 158 cell = '\n'.join([
158 159 'import time',
159 160 'time.sleep(0.25)',
160 161 '5'
161 162 ])
162 163 ar = c[t].execute("import time; time.sleep(1)", silent=False)
163 164 # give the monitor time to notice the message
164 165 time.sleep(.25)
165 ahr = self.client.get_result(ar.msg_ids[0])
166 self.assertTrue(isinstance(ahr, AsyncHubResult))
166 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
167 self.assertIsInstance(ahr, AsyncHubResult)
167 168 self.assertEqual(ahr.get().execute_result, ar.get().execute_result)
168 169 ar2 = self.client.get_result(ar.msg_ids[0])
169 self.assertFalse(isinstance(ar2, AsyncHubResult))
170 self.assertNotIsInstance(ar2, AsyncHubResult)
171 self.assertEqual(ahr.get(), ar2.get())
170 172 c.close()
171 173
172 174 def test_ids_list(self):
173 175 """test client.ids"""
174 176 ids = self.client.ids
175 177 self.assertEqual(ids, self.client._ids)
176 178 self.assertFalse(ids is self.client._ids)
177 179 ids.remove(ids[-1])
178 180 self.assertNotEqual(ids, self.client._ids)
179 181
180 182 def test_queue_status(self):
181 183 ids = self.client.ids
182 184 id0 = ids[0]
183 185 qs = self.client.queue_status(targets=id0)
184 186 self.assertTrue(isinstance(qs, dict))
185 187 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
186 188 allqs = self.client.queue_status()
187 189 self.assertTrue(isinstance(allqs, dict))
188 190 intkeys = list(allqs.keys())
189 191 intkeys.remove('unassigned')
190 192 print("intkeys", intkeys)
191 193 intkeys = sorted(intkeys)
192 194 ids = self.client.ids
193 195 print("client.ids", ids)
194 196 ids = sorted(self.client.ids)
195 197 self.assertEqual(intkeys, ids)
196 198 unassigned = allqs.pop('unassigned')
197 199 for eid,qs in allqs.items():
198 200 self.assertTrue(isinstance(qs, dict))
199 201 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
200 202
201 203 def test_shutdown(self):
202 204 ids = self.client.ids
203 205 id0 = ids[0]
204 206 self.client.shutdown(id0, block=True)
205 207 while id0 in self.client.ids:
206 208 time.sleep(0.1)
207 209 self.client.spin()
208 210
209 211 self.assertRaises(IndexError, lambda : self.client[id0])
210 212
211 213 def test_result_status(self):
212 214 pass
213 215 # to be written
214 216
215 217 def test_db_query_dt(self):
216 218 """test db query by date"""
217 219 hist = self.client.hub_history()
218 220 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
219 221 tic = middle['submitted']
220 222 before = self.client.db_query({'submitted' : {'$lt' : tic}})
221 223 after = self.client.db_query({'submitted' : {'$gte' : tic}})
222 224 self.assertEqual(len(before)+len(after),len(hist))
223 225 for b in before:
224 226 self.assertTrue(b['submitted'] < tic)
225 227 for a in after:
226 228 self.assertTrue(a['submitted'] >= tic)
227 229 same = self.client.db_query({'submitted' : tic})
228 230 for s in same:
229 231 self.assertTrue(s['submitted'] == tic)
230 232
231 233 def test_db_query_keys(self):
232 234 """test extracting subset of record keys"""
233 235 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
234 236 for rec in found:
235 237 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
236 238
237 239 def test_db_query_default_keys(self):
238 240 """default db_query excludes buffers"""
239 241 found = self.client.db_query({'msg_id': {'$ne' : ''}})
240 242 for rec in found:
241 243 keys = set(rec.keys())
242 244 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
243 245 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
244 246
245 247 def test_db_query_msg_id(self):
246 248 """ensure msg_id is always in db queries"""
247 249 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
248 250 for rec in found:
249 251 self.assertTrue('msg_id' in rec.keys())
250 252 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
251 253 for rec in found:
252 254 self.assertTrue('msg_id' in rec.keys())
253 255 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
254 256 for rec in found:
255 257 self.assertTrue('msg_id' in rec.keys())
256 258
257 259 def test_db_query_get_result(self):
258 260 """pop in db_query shouldn't pop from result itself"""
259 261 self.client[:].apply_sync(lambda : 1)
260 262 found = self.client.db_query({'msg_id': {'$ne' : ''}})
261 263 rc2 = clientmod.Client(profile='iptest')
262 264 # If this bug is not fixed, this call will hang:
263 265 ar = rc2.get_result(self.client.history[-1])
264 266 ar.wait(2)
265 267 self.assertTrue(ar.ready())
266 268 ar.get()
267 269 rc2.close()
268 270
269 271 def test_db_query_in(self):
270 272 """test db query with '$in','$nin' operators"""
271 273 hist = self.client.hub_history()
272 274 even = hist[::2]
273 275 odd = hist[1::2]
274 276 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
275 277 found = [ r['msg_id'] for r in recs ]
276 278 self.assertEqual(set(even), set(found))
277 279 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
278 280 found = [ r['msg_id'] for r in recs ]
279 281 self.assertEqual(set(odd), set(found))
280 282
281 283 def test_hub_history(self):
282 284 hist = self.client.hub_history()
283 285 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
284 286 recdict = {}
285 287 for rec in recs:
286 288 recdict[rec['msg_id']] = rec
287 289
288 290 latest = datetime(1984,1,1)
289 291 for msg_id in hist:
290 292 rec = recdict[msg_id]
291 293 newt = rec['submitted']
292 294 self.assertTrue(newt >= latest)
293 295 latest = newt
294 296 ar = self.client[-1].apply_async(lambda : 1)
295 297 ar.get()
296 298 time.sleep(0.25)
297 299 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
298 300
299 301 def _wait_for_idle(self):
300 302 """wait for the cluster to become idle, according to the everyone."""
301 303 rc = self.client
302 304
303 305 # step 0. wait for local results
304 306 # this should be sufficient 99% of the time.
305 307 rc.wait(timeout=5)
306 308
307 309 # step 1. wait for all requests to be noticed
308 310 # timeout 5s, polling every 100ms
309 311 msg_ids = set(rc.history)
310 312 hub_hist = rc.hub_history()
311 313 for i in range(50):
312 314 if msg_ids.difference(hub_hist):
313 315 time.sleep(0.1)
314 316 hub_hist = rc.hub_history()
315 317 else:
316 318 break
317 319
318 320 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319 321
320 322 # step 2. wait for all requests to be done
321 323 # timeout 5s, polling every 100ms
322 324 qs = rc.queue_status()
323 325 for i in range(50):
324 326 if qs['unassigned'] or any(qs[eid]['tasks'] + qs[eid]['queue'] for eid in qs if eid != 'unassigned'):
325 327 time.sleep(0.1)
326 328 qs = rc.queue_status()
327 329 else:
328 330 break
329 331
330 332 # ensure Hub up to date:
331 333 self.assertEqual(qs['unassigned'], 0)
332 334 for eid in [ eid for eid in qs if eid != 'unassigned' ]:
333 335 self.assertEqual(qs[eid]['tasks'], 0)
334 336 self.assertEqual(qs[eid]['queue'], 0)
335 337
336 338
337 339 def test_resubmit(self):
338 340 def f():
339 341 import random
340 342 return random.random()
341 343 v = self.client.load_balanced_view()
342 344 ar = v.apply_async(f)
343 345 r1 = ar.get(1)
344 346 # give the Hub a chance to notice:
345 347 self._wait_for_idle()
346 348 ahr = self.client.resubmit(ar.msg_ids)
347 349 r2 = ahr.get(1)
348 350 self.assertFalse(r1 == r2)
349 351
350 352 def test_resubmit_chain(self):
351 353 """resubmit resubmitted tasks"""
352 354 v = self.client.load_balanced_view()
353 355 ar = v.apply_async(lambda x: x, 'x'*1024)
354 356 ar.get()
355 357 self._wait_for_idle()
356 358 ars = [ar]
357 359
358 360 for i in range(10):
359 361 ar = ars[-1]
360 362             ar2 = self.client.resubmit(ar.msg_ids)
361 363             ars.append(ar2)
362 364 [ ar.get() for ar in ars ]
363 365
364 366 def test_resubmit_header(self):
365 367 """resubmit shouldn't clobber the whole header"""
366 368 def f():
367 369 import random
368 370 return random.random()
369 371 v = self.client.load_balanced_view()
370 372 v.retries = 1
371 373 ar = v.apply_async(f)
372 374 r1 = ar.get(1)
373 375 # give the Hub a chance to notice:
374 376 self._wait_for_idle()
375 377 ahr = self.client.resubmit(ar.msg_ids)
376 378 ahr.get(1)
377 379 time.sleep(0.5)
378 380 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
379 381 h1,h2 = [ r['header'] for r in records ]
380 382 for key in set(h1.keys()).union(set(h2.keys())):
381 383 if key in ('msg_id', 'date'):
382 384 self.assertNotEqual(h1[key], h2[key])
383 385 else:
384 386 self.assertEqual(h1[key], h2[key])
385 387
386 388 def test_resubmit_aborted(self):
387 389 def f():
388 390 import random
389 391 return random.random()
390 392 v = self.client.load_balanced_view()
391 393 # restrict to one engine, so we can put a sleep
392 394 # ahead of the task, so it will get aborted
393 395 eid = self.client.ids[-1]
394 396 v.targets = [eid]
395 397 sleep = v.apply_async(time.sleep, 0.5)
396 398 ar = v.apply_async(f)
397 399 ar.abort()
398 400 self.assertRaises(error.TaskAborted, ar.get)
399 401 # Give the Hub a chance to get up to date:
400 402 self._wait_for_idle()
401 403 ahr = self.client.resubmit(ar.msg_ids)
402 404 r2 = ahr.get(1)
403 405
404 406 def test_resubmit_inflight(self):
405 407 """resubmit of inflight task"""
406 408 v = self.client.load_balanced_view()
407 409 ar = v.apply_async(time.sleep,1)
408 410 # give the message a chance to arrive
409 411 time.sleep(0.2)
410 412 ahr = self.client.resubmit(ar.msg_ids)
411 413 ar.get(2)
412 414 ahr.get(2)
413 415
414 416 def test_resubmit_badkey(self):
415 417         """ensure KeyError on resubmit of a nonexistent task"""
416 418 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
417 419
418 420 def test_purge_hub_results(self):
419 421 # ensure there are some tasks
420 422 for i in range(5):
421 423 self.client[:].apply_sync(lambda : 1)
422 424 # Wait for the Hub to realise the result is done:
423 425 # This prevents a race condition, where we
424 426 # might purge a result the Hub still thinks is pending.
425 427 self._wait_for_idle()
426 428 rc2 = clientmod.Client(profile='iptest')
427 429 hist = self.client.hub_history()
428 430 ahr = rc2.get_result([hist[-1]])
429 431 ahr.wait(10)
430 432 self.client.purge_hub_results(hist[-1])
431 433 newhist = self.client.hub_history()
432 434 self.assertEqual(len(newhist)+1,len(hist))
433 435 rc2.spin()
434 436 rc2.close()
435 437
436 438 def test_purge_local_results(self):
437 439 # ensure there are some tasks
438 440 res = []
439 441 for i in range(5):
440 442 res.append(self.client[:].apply_async(lambda : 1))
441 443 self._wait_for_idle()
442 444 self.client.wait(10) # wait for the results to come back
443 445 before = len(self.client.results)
444 446 self.assertEqual(len(self.client.metadata),before)
445 447 self.client.purge_local_results(res[-1])
446 448 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
447 449 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
448 450
449 451 def test_purge_local_results_outstanding(self):
450 452 v = self.client[-1]
451 453 ar = v.apply_async(lambda : 1)
452 454 msg_id = ar.msg_ids[0]
455 ar.owner = False
453 456 ar.get()
454 457 self._wait_for_idle()
455 458 ar2 = v.apply_async(time.sleep, 1)
456 459 self.assertIn(msg_id, self.client.results)
457 460 self.assertIn(msg_id, self.client.metadata)
458 461 self.client.purge_local_results(ar)
459 462 self.assertNotIn(msg_id, self.client.results)
460 463 self.assertNotIn(msg_id, self.client.metadata)
461 464 with self.assertRaises(RuntimeError):
462 465 self.client.purge_local_results(ar2)
463 466 ar2.get()
464 467 self.client.purge_local_results(ar2)
465 468
466 469 def test_purge_all_local_results_outstanding(self):
467 470 v = self.client[-1]
468 471 ar = v.apply_async(time.sleep, 1)
469 472 with self.assertRaises(RuntimeError):
470 473 self.client.purge_local_results('all')
471 474 ar.get()
472 475 self.client.purge_local_results('all')
473 476
474 477 def test_purge_all_hub_results(self):
475 478 self.client.purge_hub_results('all')
476 479 hist = self.client.hub_history()
477 480 self.assertEqual(len(hist), 0)
478 481
479 482 def test_purge_all_local_results(self):
480 483 self.client.purge_local_results('all')
481 484 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
482 485 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
483 486
484 487 def test_purge_all_results(self):
485 488 # ensure there are some tasks
486 489 for i in range(5):
487 490 self.client[:].apply_sync(lambda : 1)
488 491 self.client.wait(10)
489 492 self._wait_for_idle()
490 493 self.client.purge_results('all')
491 494 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
492 495 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
493 496 hist = self.client.hub_history()
494 497 self.assertEqual(len(hist), 0, msg="hub history not empty")
495 498
496 499 def test_purge_everything(self):
497 500 # ensure there are some tasks
498 501 for i in range(5):
499 502 self.client[:].apply_sync(lambda : 1)
500 503 self.client.wait(10)
501 504 self._wait_for_idle()
502 505 self.client.purge_everything()
503 506 # The client results
504 507 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
505 508 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
506 509 # The client "bookkeeping"
507 510 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
508 511 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
509 512 # the hub results
510 513 hist = self.client.hub_history()
511 514 self.assertEqual(len(hist), 0, msg="hub history not empty")
512 515
513 516
514 517 def test_spin_thread(self):
515 518 self.client.spin_thread(0.01)
516 519 ar = self.client[-1].apply_async(lambda : 1)
517 520 md = self.client.metadata[ar.msg_ids[0]]
518 521 # 3s timeout, 100ms poll
519 522 for i in range(30):
520 523 time.sleep(0.1)
521 524 if md['received'] is not None:
522 525 break
523 526 self.assertIsInstance(md['received'], datetime)
524 527
525 528 def test_stop_spin_thread(self):
526 529 self.client.spin_thread(0.01)
527 530 self.client.stop_spin_thread()
528 531 ar = self.client[-1].apply_async(lambda : 1)
529 532 md = self.client.metadata[ar.msg_ids[0]]
530 533 # 500ms timeout, 100ms poll
531 534 for i in range(5):
532 535 time.sleep(0.1)
533 536             self.assertIsNone(md['received'])
534 537
535 538 def test_activate(self):
536 539 ip = get_ipython()
537 540 magics = ip.magics_manager.magics
538 541 self.assertTrue('px' in magics['line'])
539 542 self.assertTrue('px' in magics['cell'])
540 543 v0 = self.client.activate(-1, '0')
541 544 self.assertTrue('px0' in magics['line'])
542 545 self.assertTrue('px0' in magics['cell'])
543 546 self.assertEqual(v0.targets, self.client.ids[-1])
544 547 v0 = self.client.activate('all', 'all')
545 548 self.assertTrue('pxall' in magics['line'])
546 549 self.assertTrue('pxall' in magics['cell'])
547 550 self.assertEqual(v0.targets, 'all')
@@ -1,842 +1,843 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects"""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import base64
8 8 import sys
9 9 import platform
10 10 import time
11 11 from collections import namedtuple
12 12 from tempfile import NamedTemporaryFile
13 13
14 14 import zmq
15 15 from nose.plugins.attrib import attr
16 16
17 17 from IPython.testing import decorators as dec
18 18 from IPython.utils.io import capture_output
19 19 from IPython.utils.py3compat import unicode_type
20 20
21 21 from IPython import parallel as pmod
22 22 from IPython.parallel import error
23 23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 24 from IPython.parallel.util import interactive
25 25
26 26 from IPython.parallel.tests import add_engines
27 27
28 28 from .clienttest import ClusterTestCase, crash, wait, skip_without
29 29
30 30 def setup():
31 31 add_engines(3, total=True)
32 32
33 33 point = namedtuple("point", "x y")
34 34
35 35 class TestView(ClusterTestCase):
36 36
37 37 def setUp(self):
38 38 # On Win XP, wait for resource cleanup, else parallel test group fails
39 39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
40 40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
41 41 time.sleep(2)
42 42 super(TestView, self).setUp()
43 43
44 44 @attr('crash')
45 45 def test_z_crash_mux(self):
46 46 """test graceful handling of engine death (direct)"""
47 47 # self.add_engines(1)
48 48 eid = self.client.ids[-1]
49 49 ar = self.client[eid].apply_async(crash)
50 50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 51 eid = ar.engine_id
52 52 tic = time.time()
53 53 while eid in self.client.ids and time.time()-tic < 5:
54 54 time.sleep(.01)
55 55 self.client.spin()
56 56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 57
58 58 def test_push_pull(self):
59 59 """test pushing and pulling"""
60 60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
61 61 t = self.client.ids[-1]
62 62 v = self.client[t]
63 63 push = v.push
64 64 pull = v.pull
65 65 v.block=True
66 66 nengines = len(self.client)
67 67 push({'data':data})
68 68 d = pull('data')
69 69 self.assertEqual(d, data)
70 70 self.client[:].push({'data':data})
71 71 d = self.client[:].pull('data', block=True)
72 72 self.assertEqual(d, nengines*[data])
73 73 ar = push({'data':data}, block=False)
74 74 self.assertTrue(isinstance(ar, AsyncResult))
75 75 r = ar.get()
76 76 ar = self.client[:].pull('data', block=False)
77 77 self.assertTrue(isinstance(ar, AsyncResult))
78 78 r = ar.get()
79 79 self.assertEqual(r, nengines*[data])
80 80 self.client[:].push(dict(a=10,b=20))
81 81 r = self.client[:].pull(('a','b'), block=True)
82 82 self.assertEqual(r, nengines*[[10,20]])
83 83
84 84 def test_push_pull_function(self):
85 85 "test pushing and pulling functions"
86 86 def testf(x):
87 87 return 2.0*x
88 88
89 89 t = self.client.ids[-1]
90 90 v = self.client[t]
91 91 v.block=True
92 92 push = v.push
93 93 pull = v.pull
94 94 execute = v.execute
95 95 push({'testf':testf})
96 96 r = pull('testf')
97 97 self.assertEqual(r(1.0), testf(1.0))
98 98 execute('r = testf(10)')
99 99 r = pull('r')
100 100 self.assertEqual(r, testf(10))
101 101 ar = self.client[:].push({'testf':testf}, block=False)
102 102 ar.get()
103 103 ar = self.client[:].pull('testf', block=False)
104 104 rlist = ar.get()
105 105 for r in rlist:
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute("def g(x): return x*x")
108 108 r = pull(('testf','g'))
109 109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
110 110
111 111 def test_push_function_globals(self):
112 112 """test that pushed functions have access to globals"""
113 113 @interactive
114 114 def geta():
115 115 return a
116 116 # self.add_engines(1)
117 117 v = self.client[-1]
118 118 v.block=True
119 119 v['f'] = geta
120 120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 121 v.execute('a=5')
122 122 v.execute('b=f()')
123 123 self.assertEqual(v['b'], 5)
124 124
125 125 def test_push_function_defaults(self):
126 126 """test that pushed functions preserve default args"""
127 127 def echo(a=10):
128 128 return a
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = echo
132 132 v.execute('b=f()')
133 133 self.assertEqual(v['b'], 10)
134 134
135 135 def test_get_result(self):
136 136 """test getting results from the Hub."""
137 137 c = pmod.Client(profile='iptest')
138 138 # self.add_engines(1)
139 139 t = c.ids[-1]
140 140 v = c[t]
141 141 v2 = self.client[t]
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids[0])
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
145 ahr = v2.get_result(ar.msg_ids[0], owner=False)
146 self.assertIsInstance(ahr, AsyncHubResult)
147 147 self.assertEqual(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids[0])
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertNotIsInstance(ar2, AsyncHubResult)
150 self.assertEqual(ahr.get(), ar2.get())
150 151 c.spin()
151 152 c.close()
152 153
153 154 def test_run_newline(self):
154 155 """test that run appends newline to files"""
155 156 with NamedTemporaryFile('w', delete=False) as f:
156 157 f.write("""def g():
157 158 return 5
158 159 """)
159 160 v = self.client[-1]
160 161 v.run(f.name, block=True)
161 162 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
162 163
163 164 def test_apply_tracked(self):
164 165 """test tracking for apply"""
165 166 # self.add_engines(1)
166 167 t = self.client.ids[-1]
167 168 v = self.client[t]
168 169 v.block=False
169 170 def echo(n=1024*1024, **kwargs):
170 171 with v.temp_flags(**kwargs):
171 172 return v.apply(lambda x: x, 'x'*n)
172 173 ar = echo(1, track=False)
173 174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 175 self.assertTrue(ar.sent)
175 176 ar = echo(track=True)
176 177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
177 178 self.assertEqual(ar.sent, ar._tracker.done)
178 179 ar._tracker.wait()
179 180 self.assertTrue(ar.sent)
180 181
181 182 def test_push_tracked(self):
182 183 t = self.client.ids[-1]
183 184 ns = dict(x='x'*1024*1024)
184 185 v = self.client[t]
185 186 ar = v.push(ns, block=False, track=False)
186 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 188 self.assertTrue(ar.sent)
188 189
189 190 ar = v.push(ns, block=False, track=True)
190 191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
191 192 ar._tracker.wait()
192 193 self.assertEqual(ar.sent, ar._tracker.done)
193 194 self.assertTrue(ar.sent)
194 195 ar.get()
195 196
196 197 def test_scatter_tracked(self):
197 198 t = self.client.ids
198 199 x='x'*1024*1024
199 200 ar = self.client[t].scatter('x', x, block=False, track=False)
200 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 202 self.assertTrue(ar.sent)
202 203
203 204 ar = self.client[t].scatter('x', x, block=False, track=True)
204 205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
205 206 self.assertEqual(ar.sent, ar._tracker.done)
206 207 ar._tracker.wait()
207 208 self.assertTrue(ar.sent)
208 209 ar.get()
209 210
210 211 def test_remote_reference(self):
211 212 v = self.client[-1]
212 213 v['a'] = 123
213 214 ra = pmod.Reference('a')
214 215 b = v.apply_sync(lambda x: x, ra)
215 216 self.assertEqual(b, 123)
216 217
217 218
218 219 def test_scatter_gather(self):
219 220 view = self.client[:]
220 221 seq1 = list(range(16))
221 222 view.scatter('a', seq1)
222 223 seq2 = view.gather('a', block=True)
223 224 self.assertEqual(seq2, seq1)
224 225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
225 226
226 227 @skip_without('numpy')
227 228 def test_scatter_gather_numpy(self):
228 229 import numpy
229 230 from numpy.testing.utils import assert_array_equal
230 231 view = self.client[:]
231 232 a = numpy.arange(64)
232 233 view.scatter('a', a, block=True)
233 234 b = view.gather('a', block=True)
234 235 assert_array_equal(b, a)
235 236
236 237 def test_scatter_gather_lazy(self):
237 238 """scatter/gather with targets='all'"""
238 239 view = self.client.direct_view(targets='all')
239 240 x = list(range(64))
240 241 view.scatter('x', x)
241 242 gathered = view.gather('x', block=True)
242 243 self.assertEqual(gathered, x)
243 244
244 245
245 246 @dec.known_failure_py3
246 247 @skip_without('numpy')
247 248 def test_push_numpy_nocopy(self):
248 249 import numpy
249 250 view = self.client[:]
250 251 a = numpy.arange(64)
251 252 view['A'] = a
252 253 @interactive
253 254 def check_writeable(x):
254 255 return x.flags.writeable
255 256
256 257 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
257 258 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
258 259
259 260 view.push(dict(B=a))
260 261 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
261 262 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
262 263
263 264 @skip_without('numpy')
264 265 def test_apply_numpy(self):
265 266 """view.apply(f, ndarray)"""
266 267 import numpy
267 268 from numpy.testing.utils import assert_array_equal
268 269
269 270 A = numpy.random.random((100,100))
270 271 view = self.client[-1]
271 272 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
272 273 B = A.astype(dt)
273 274 C = view.apply_sync(lambda x:x, B)
274 275 assert_array_equal(B,C)
275 276
276 277 @skip_without('numpy')
277 278 def test_apply_numpy_object_dtype(self):
278 279 """view.apply(f, ndarray) with dtype=object"""
279 280 import numpy
280 281 from numpy.testing.utils import assert_array_equal
281 282 view = self.client[-1]
282 283
283 284 A = numpy.array([dict(a=5)])
284 285 B = view.apply_sync(lambda x:x, A)
285 286 assert_array_equal(A,B)
286 287
287 288 A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
288 289 B = view.apply_sync(lambda x:x, A)
289 290 assert_array_equal(A,B)
290 291
291 292 @skip_without('numpy')
292 293 def test_push_pull_recarray(self):
293 294 """push/pull recarrays"""
294 295 import numpy
295 296 from numpy.testing.utils import assert_array_equal
296 297
297 298 view = self.client[-1]
298 299
299 300 R = numpy.array([
300 301 (1, 'hi', 0.),
301 302 (2**30, 'there', 2.5),
302 303 (-99999, 'world', -12345.6789),
303 304 ], [('n', int), ('s', '|S10'), ('f', float)])
304 305
305 306 view['RR'] = R
306 307 R2 = view['RR']
307 308
308 309 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
309 310 self.assertEqual(r_dtype, R.dtype)
310 311 self.assertEqual(r_shape, R.shape)
311 312 self.assertEqual(R2.dtype, R.dtype)
312 313 self.assertEqual(R2.shape, R.shape)
313 314 assert_array_equal(R2, R)
314 315
315 316 @skip_without('pandas')
316 317 def test_push_pull_timeseries(self):
317 318 """push/pull pandas.TimeSeries"""
318 319 import pandas
319 320
320 321 ts = pandas.TimeSeries(list(range(10)))
321 322
322 323 view = self.client[-1]
323 324
324 325 view.push(dict(ts=ts), block=True)
325 326 rts = view['ts']
326 327
327 328 self.assertEqual(type(rts), type(ts))
328 329 self.assertTrue((ts == rts).all())
329 330
330 331 def test_map(self):
331 332 view = self.client[:]
332 333 def f(x):
333 334 return x**2
334 335 data = list(range(16))
335 336 r = view.map_sync(f, data)
336 337 self.assertEqual(r, list(map(f, data)))
337 338
338 339 def test_map_empty_sequence(self):
339 340 view = self.client[:]
340 341 r = view.map_sync(lambda x: x, [])
341 342 self.assertEqual(r, [])
342 343
343 344 def test_map_iterable(self):
344 345 """test map on iterables (direct)"""
345 346 view = self.client[:]
346 347 # 101 is prime, so it won't be evenly distributed
347 348 arr = range(101)
348 349 # ensure it will be an iterator, even in Python 3
349 350 it = iter(arr)
350 351 r = view.map_sync(lambda x: x, it)
351 352 self.assertEqual(r, list(arr))
352 353
353 354 @skip_without('numpy')
354 355 def test_map_numpy(self):
355 356 """test map on numpy arrays (direct)"""
356 357 import numpy
357 358 from numpy.testing.utils import assert_array_equal
358 359
359 360 view = self.client[:]
360 361 # 101 is prime, so it won't be evenly distributed
361 362 arr = numpy.arange(101)
362 363 r = view.map_sync(lambda x: x, arr)
363 364 assert_array_equal(r, arr)
364 365
365 366 def test_scatter_gather_nonblocking(self):
366 367 data = list(range(16))
367 368 view = self.client[:]
368 369 view.scatter('a', data, block=False)
369 370 ar = view.gather('a', block=False)
370 371 self.assertEqual(ar.get(), data)
371 372
372 373 @skip_without('numpy')
373 374 def test_scatter_gather_numpy_nonblocking(self):
374 375 import numpy
375 376 from numpy.testing.utils import assert_array_equal
376 377 a = numpy.arange(64)
377 378 view = self.client[:]
378 379 ar = view.scatter('a', a, block=False)
379 380 self.assertTrue(isinstance(ar, AsyncResult))
380 381 amr = view.gather('a', block=False)
381 382 self.assertTrue(isinstance(amr, AsyncMapResult))
382 383 assert_array_equal(amr.get(), a)
383 384
384 385 def test_execute(self):
385 386 view = self.client[:]
386 387 # self.client.debug=True
387 388 execute = view.execute
388 389 ar = execute('c=30', block=False)
389 390 self.assertTrue(isinstance(ar, AsyncResult))
390 391 ar = execute('d=[0,1,2]', block=False)
391 392 self.client.wait(ar, 1)
392 393 self.assertEqual(len(ar.get()), len(self.client))
393 394 for c in view['c']:
394 395 self.assertEqual(c, 30)
395 396
396 397 def test_abort(self):
397 398 view = self.client[-1]
398 399 ar = view.execute('import time; time.sleep(1)', block=False)
399 400 ar2 = view.apply_async(lambda : 2)
400 401 ar3 = view.apply_async(lambda : 3)
401 402 view.abort(ar2)
402 403 view.abort(ar3.msg_ids)
403 404 self.assertRaises(error.TaskAborted, ar2.get)
404 405 self.assertRaises(error.TaskAborted, ar3.get)
405 406
406 407 def test_abort_all(self):
407 408 """view.abort() aborts all outstanding tasks"""
408 409 view = self.client[-1]
409 410 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
410 411 view.abort()
411 412 view.wait(timeout=5)
412 413 for ar in ars[5:]:
413 414 self.assertRaises(error.TaskAborted, ar.get)
414 415
415 416 def test_temp_flags(self):
416 417 view = self.client[-1]
417 418 view.block=True
418 419 with view.temp_flags(block=False):
419 420 self.assertFalse(view.block)
420 421 self.assertTrue(view.block)
421 422
422 423 @dec.known_failure_py3
423 424 def test_importer(self):
424 425 view = self.client[-1]
425 426 view.clear(block=True)
426 427 with view.importer:
427 428 import re
428 429
429 430 @interactive
430 431 def findall(pat, s):
431 432 # this globals() step isn't necessary in real code
432 433 # only to prevent a closure in the test
433 434 re = globals()['re']
434 435 return re.findall(pat, s)
435 436
436 437 self.assertEqual(view.apply_sync(findall, r'\w+', 'hello world'), 'hello world'.split())
437 438
438 439 def test_unicode_execute(self):
439 440 """test executing unicode strings"""
440 441 v = self.client[-1]
441 442 v.block=True
442 443 if sys.version_info[0] >= 3:
443 444 code="a='é'"
444 445 else:
445 446 code=u"a=u'é'"
446 447 v.execute(code)
447 448 self.assertEqual(v['a'], u'é')
448 449
449 450 def test_unicode_apply_result(self):
450 451 """test unicode apply results"""
451 452 v = self.client[-1]
452 453 r = v.apply_sync(lambda : u'é')
453 454 self.assertEqual(r, u'é')
454 455
455 456 def test_unicode_apply_arg(self):
456 457 """test passing unicode arguments to apply"""
457 458 v = self.client[-1]
458 459
459 460 @interactive
460 461 def check_unicode(a, check):
461 462 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
462 463 assert isinstance(check, bytes), "%r is not bytes"%check
463 464 assert a.encode('utf8') == check, "%s != %s"%(a,check)
464 465
465 466 for s in [ u'é', u'ßø®∫',u'asdf' ]:
466 467 try:
467 468 v.apply_sync(check_unicode, s, s.encode('utf8'))
468 469 except error.RemoteError as e:
469 470 if e.ename == 'AssertionError':
470 471 self.fail(e.evalue)
471 472 else:
472 473 raise e
473 474
474 475 def test_map_reference(self):
475 476 """view.map(<Reference>, *seqs) should work"""
476 477 v = self.client[:]
477 478 v.scatter('n', self.client.ids, flatten=True)
478 479 v.execute("f = lambda x,y: x*y")
479 480 rf = pmod.Reference('f')
480 481 nlist = list(range(10))
481 482 mlist = nlist[::-1]
482 483 expected = [ m*n for m,n in zip(mlist, nlist) ]
483 484 result = v.map_sync(rf, mlist, nlist)
484 485 self.assertEqual(result, expected)
485 486
486 487 def test_apply_reference(self):
487 488 """view.apply(<Reference>, *args) should work"""
488 489 v = self.client[:]
489 490 v.scatter('n', self.client.ids, flatten=True)
490 491 v.execute("f = lambda x: n*x")
491 492 rf = pmod.Reference('f')
492 493 result = v.apply_sync(rf, 5)
493 494 expected = [ 5*id for id in self.client.ids ]
494 495 self.assertEqual(result, expected)
495 496
496 497 def test_eval_reference(self):
497 498 v = self.client[self.client.ids[0]]
498 499 v['g'] = list(range(5))
499 500 rg = pmod.Reference('g[0]')
500 501 echo = lambda x:x
501 502 self.assertEqual(v.apply_sync(echo, rg), 0)
502 503
503 504 def test_reference_nameerror(self):
504 505 v = self.client[self.client.ids[0]]
505 506 r = pmod.Reference('elvis_has_left')
506 507 echo = lambda x:x
507 508 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
508 509
509 510 def test_single_engine_map(self):
510 511 e0 = self.client[self.client.ids[0]]
511 512 r = list(range(5))
512 513 check = [ -1*i for i in r ]
513 514 result = e0.map_sync(lambda x: -1*x, r)
514 515 self.assertEqual(result, check)
515 516
516 517 def test_len(self):
517 518 """len(view) makes sense"""
518 519 e0 = self.client[self.client.ids[0]]
519 520 self.assertEqual(len(e0), 1)
520 521 v = self.client[:]
521 522 self.assertEqual(len(v), len(self.client.ids))
522 523 v = self.client.direct_view('all')
523 524 self.assertEqual(len(v), len(self.client.ids))
524 525 v = self.client[:2]
525 526 self.assertEqual(len(v), 2)
526 527 v = self.client[:1]
527 528 self.assertEqual(len(v), 1)
528 529 v = self.client.load_balanced_view()
529 530 self.assertEqual(len(v), len(self.client.ids))
530 531
531 532
532 533 # begin execute tests
533 534
534 535 def test_execute_reply(self):
535 536 e0 = self.client[self.client.ids[0]]
536 537 e0.block = True
537 538 ar = e0.execute("5", silent=False)
538 539 er = ar.get()
539 540 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
540 541 self.assertEqual(er.execute_result['data']['text/plain'], '5')
541 542
542 543 def test_execute_reply_rich(self):
543 544 e0 = self.client[self.client.ids[0]]
544 545 e0.block = True
545 546 e0.execute("from IPython.display import Image, HTML")
546 547 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
547 548 er = ar.get()
548 549 b64data = base64.encodestring(b'garbage').decode('ascii')
549 550 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
550 551 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
551 552 er = ar.get()
552 553 self.assertEqual(er._repr_html_(), "<b>bold</b>")
553 554
554 555 def test_execute_reply_stdout(self):
555 556 e0 = self.client[self.client.ids[0]]
556 557 e0.block = True
557 558 ar = e0.execute("print (5)", silent=False)
558 559 er = ar.get()
559 560 self.assertEqual(er.stdout.strip(), '5')
560 561
561 562 def test_execute_result(self):
562 563 """execute triggers execute_result with silent=False"""
563 564 view = self.client[:]
564 565 ar = view.execute("5", silent=False, block=True)
565 566
566 567 expected = [{'text/plain' : '5'}] * len(view)
567 568 mimes = [ out['data'] for out in ar.execute_result ]
568 569 self.assertEqual(mimes, expected)
569 570
570 571 def test_execute_silent(self):
571 572 """execute does not trigger execute_result with silent=True"""
572 573 view = self.client[:]
573 574 ar = view.execute("5", block=True)
574 575 expected = [None] * len(view)
575 576 self.assertEqual(ar.execute_result, expected)
576 577
577 578 def test_execute_magic(self):
578 579 """execute accepts IPython commands"""
579 580 view = self.client[:]
580 581 view.execute("a = 5")
581 582 ar = view.execute("%whos", block=True)
583 584 # this will raise if the %whos execution failed
583 584 ar.get(5)
584 585 for stdout in ar.stdout:
585 586 lines = stdout.splitlines()
586 587 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
587 588 found = False
588 589 for line in lines[2:]:
589 590 split = line.split()
590 591 if split == ['a', 'int', '5']:
591 592 found = True
592 593 break
593 594 self.assertTrue(found, "whos output wrong: %s" % stdout)
594 595
595 596 def test_execute_displaypub(self):
596 597 """execute tracks display_pub output"""
597 598 view = self.client[:]
598 599 view.execute("from IPython.core.display import *")
599 600 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
600 601
601 602 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
602 603 for outputs in ar.outputs:
603 604 mimes = [ out['data'] for out in outputs ]
604 605 self.assertEqual(mimes, expected)
605 606
606 607 def test_apply_displaypub(self):
607 608 """apply tracks display_pub output"""
608 609 view = self.client[:]
609 610 view.execute("from IPython.core.display import *")
610 611
611 612 @interactive
612 613 def publish():
613 614 [ display(i) for i in range(5) ]
614 615
615 616 ar = view.apply_async(publish)
616 617 ar.get(5)
617 618 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
618 619 for outputs in ar.outputs:
619 620 mimes = [ out['data'] for out in outputs ]
620 621 self.assertEqual(mimes, expected)
621 622
622 623 def test_execute_raises(self):
623 624 """exceptions in execute requests raise appropriately"""
624 625 view = self.client[-1]
625 626 ar = view.execute("1/0")
626 627 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
627 628
628 629 def test_remoteerror_render_exception(self):
629 630 """RemoteErrors get nice tracebacks"""
630 631 view = self.client[-1]
631 632 ar = view.execute("1/0")
632 633 ip = get_ipython()
633 634 ip.user_ns['ar'] = ar
634 635 with capture_output() as io:
635 636 ip.run_cell("ar.get(2)")
636 637
637 638 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
638 639
639 640 def test_compositeerror_render_exception(self):
640 641 """CompositeErrors get nice tracebacks"""
641 642 view = self.client[:]
642 643 ar = view.execute("1/0")
643 644 ip = get_ipython()
644 645 ip.user_ns['ar'] = ar
645 646
646 647 with capture_output() as io:
647 648 ip.run_cell("ar.get(2)")
648 649
649 650 count = min(error.CompositeError.tb_limit, len(view))
650 651
651 652 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
652 653 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
653 654 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
654 655
655 656 def test_compositeerror_truncate(self):
656 657 """Truncate CompositeErrors with many exceptions"""
657 658 view = self.client[:]
658 659 msg_ids = []
659 660 for i in range(10):
660 661 ar = view.execute("1/0")
661 662 msg_ids.extend(ar.msg_ids)
662 663
663 664 ar = self.client.get_result(msg_ids)
664 665 try:
665 666 ar.get()
666 667 except error.CompositeError as _e:
667 668 e = _e
668 669 else:
669 670 self.fail("Should have raised CompositeError")
670 671
671 672 lines = e.render_traceback()
672 673 with capture_output() as io:
673 674 e.print_traceback()
674 675
675 676 self.assertTrue("more exceptions" in lines[-1])
676 677 count = e.tb_limit
677 678
678 679 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
679 680 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
680 681 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
681 682
682 683 @dec.skipif_not_matplotlib
683 684 def test_magic_pylab(self):
684 685 """%pylab works on engines"""
685 686 view = self.client[-1]
686 687 ar = view.execute("%pylab inline")
687 688 # at minimum, make sure this did not raise:
688 689 reply = ar.get(5)
689 690 # %pylab above should have imported plot/rand, regardless of user config
690 691 ar = view.execute("plot(rand(100))", silent=False)
691 692 reply = ar.get(5)
692 693 self.assertEqual(len(reply.outputs), 1)
693 694 output = reply.outputs[0]
694 695 self.assertTrue("data" in output)
695 696 data = output['data']
696 697 self.assertTrue("image/png" in data)
697 698
698 699 def test_func_default_func(self):
699 700 """interactively defined function as apply func default"""
700 701 def foo():
701 702 return 'foo'
702 703
703 704 def bar(f=foo):
704 705 return f()
705 706
706 707 view = self.client[-1]
707 708 ar = view.apply_async(bar)
708 709 r = ar.get(10)
709 710 self.assertEqual(r, 'foo')
710 711 def test_data_pub_single(self):
711 712 view = self.client[-1]
712 713 ar = view.execute('\n'.join([
713 714 'from IPython.kernel.zmq.datapub import publish_data',
714 715 'for i in range(5):',
715 716 ' publish_data(dict(i=i))'
716 717 ]), block=False)
717 718 self.assertTrue(isinstance(ar.data, dict))
718 719 ar.get(5)
719 720 self.assertEqual(ar.data, dict(i=4))
720 721
721 722 def test_data_pub(self):
722 723 view = self.client[:]
723 724 ar = view.execute('\n'.join([
724 725 'from IPython.kernel.zmq.datapub import publish_data',
725 726 'for i in range(5):',
726 727 ' publish_data(dict(i=i))'
727 728 ]), block=False)
728 729 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
729 730 ar.get(5)
730 731 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
731 732
732 733 def test_can_list_arg(self):
733 734 """args in lists are canned"""
734 735 view = self.client[-1]
735 736 view['a'] = 128
736 737 rA = pmod.Reference('a')
737 738 ar = view.apply_async(lambda x: x, [rA])
738 739 r = ar.get(5)
739 740 self.assertEqual(r, [128])
740 741
741 742 def test_can_dict_arg(self):
742 743 """args in dicts are canned"""
743 744 view = self.client[-1]
744 745 view['a'] = 128
745 746 rA = pmod.Reference('a')
746 747 ar = view.apply_async(lambda x: x, dict(foo=rA))
747 748 r = ar.get(5)
748 749 self.assertEqual(r, dict(foo=128))
749 750
750 751 def test_can_list_kwarg(self):
751 752 """kwargs in lists are canned"""
752 753 view = self.client[-1]
753 754 view['a'] = 128
754 755 rA = pmod.Reference('a')
755 756 ar = view.apply_async(lambda x=5: x, x=[rA])
756 757 r = ar.get(5)
757 758 self.assertEqual(r, [128])
758 759
759 760 def test_can_dict_kwarg(self):
760 761 """kwargs in dicts are canned"""
761 762 view = self.client[-1]
762 763 view['a'] = 128
763 764 rA = pmod.Reference('a')
764 765 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
765 766 r = ar.get(5)
766 767 self.assertEqual(r, dict(foo=128))
767 768
768 769 def test_map_ref(self):
769 770 """view.map works with references"""
770 771 view = self.client[:]
771 772 ranks = sorted(self.client.ids)
772 773 view.scatter('rank', ranks, flatten=True)
773 774 rrank = pmod.Reference('rank')
774 775
775 776 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
776 777 drank = amr.get(5)
777 778 self.assertEqual(drank, [ r*2 for r in ranks ])
778 779
779 780 def test_nested_getitem_setitem(self):
780 781 """get and set with view['a.b']"""
781 782 view = self.client[-1]
782 783 view.execute('\n'.join([
783 784 'class A(object): pass',
784 785 'a = A()',
785 786 'a.b = 128',
786 787 ]), block=True)
787 788 ra = pmod.Reference('a')
788 789
789 790 r = view.apply_sync(lambda x: x.b, ra)
790 791 self.assertEqual(r, 128)
791 792 self.assertEqual(view['a.b'], 128)
792 793
793 794 view['a.b'] = 0
794 795
795 796 r = view.apply_sync(lambda x: x.b, ra)
796 797 self.assertEqual(r, 0)
797 798 self.assertEqual(view['a.b'], 0)
798 799
799 800 def test_return_namedtuple(self):
800 801 def namedtuplify(x, y):
801 802 from IPython.parallel.tests.test_view import point
802 803 return point(x, y)
803 804
804 805 view = self.client[-1]
805 806 p = view.apply_sync(namedtuplify, 1, 2)
806 807 self.assertEqual(p.x, 1)
807 808 self.assertEqual(p.y, 2)
808 809
809 810 def test_apply_namedtuple(self):
810 811 def echoxy(p):
811 812 return p.y, p.x
812 813
813 814 view = self.client[-1]
814 815 tup = view.apply_sync(echoxy, point(1, 2))
815 816 self.assertEqual(tup, (2,1))
816 817
817 818 def test_sync_imports(self):
818 819 view = self.client[-1]
819 820 with capture_output() as io:
820 821 with view.sync_imports():
821 822 import IPython
822 823 self.assertIn("IPython", io.stdout)
823 824
824 825 @interactive
825 826 def find_ipython():
826 827 return 'IPython' in globals()
827 828
828 829 assert view.apply_sync(find_ipython)
829 830
830 831 def test_sync_imports_quiet(self):
831 832 view = self.client[-1]
832 833 with capture_output() as io:
833 834 with view.sync_imports(quiet=True):
834 835 import IPython
835 836 self.assertEqual(io.stdout, '')
836 837
837 838 @interactive
838 839 def find_ipython():
839 840 return 'IPython' in globals()
840 841
841 842 assert view.apply_sync(find_ipython)
842 843
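Below is a minimal sketch (not part of the diff) of the scenario the updated test_get_result covers: a second client fetching a result it did not submit, first from the Hub with owner=False, then again from its own cache. It assumes a running IPython.parallel cluster on the 'iptest' profile, as the test suite uses; the names c1, c2, and eid are illustrative, and reading owner=False as "do not take ownership of (or later purge) the fetched messages" is an inference from this commit, not documented API.

    import time
    from IPython import parallel as pmod
    from IPython.parallel import AsyncHubResult

    c1 = pmod.Client(profile='iptest')   # submits the work
    c2 = pmod.Client(profile='iptest')   # fetches it later
    eid = c1.ids[-1]

    ar = c1[eid].apply_async(lambda: 42)
    ar.get()          # wait for the result itself
    time.sleep(0.25)  # give the Hub's monitor time to record it

    # c2 never saw the request, so get_result has to ask the Hub and
    # returns an AsyncHubResult; owner=False presumably keeps that
    # object from claiming (and eventually purging) the messages.
    ahr = c2[eid].get_result(ar.msg_ids[0], owner=False)
    assert isinstance(ahr, AsyncHubResult)
    assert ahr.get() == 42

    # A second request for the same msg_id is served from c2's local
    # cache, so a plain AsyncResult comes back instead.
    ar2 = c2[eid].get_result(ar.msg_ids[0])
    assert not isinstance(ar2, AsyncHubResult)
    assert ar2.get() == ahr.get()

    c1.close()
    c2.close()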