Show More
@@ -73,6 +73,9 b' class AsyncResult(object):' | |||||
73 | if isinstance(msg_ids, basestring): |
|
73 | if isinstance(msg_ids, basestring): | |
74 | # always a list |
|
74 | # always a list | |
75 | msg_ids = [msg_ids] |
|
75 | msg_ids = [msg_ids] | |
|
76 | self._single_result = True | |||
|
77 | else: | |||
|
78 | self._single_result = False | |||
76 | if tracker is None: |
|
79 | if tracker is None: | |
77 | # default to always done |
|
80 | # default to always done | |
78 | tracker = finished_tracker |
|
81 | tracker = finished_tracker | |
@@ -81,14 +84,11 b' class AsyncResult(object):' | |||||
81 | self._fname=fname |
|
84 | self._fname=fname | |
82 | self._targets = targets |
|
85 | self._targets = targets | |
83 | self._tracker = tracker |
|
86 | self._tracker = tracker | |
|
87 | ||||
84 | self._ready = False |
|
88 | self._ready = False | |
85 | self._outputs_ready = False |
|
89 | self._outputs_ready = False | |
86 | self._success = None |
|
90 | self._success = None | |
87 | self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ] |
|
91 | self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ] | |
88 | if len(msg_ids) == 1: |
|
|||
89 | self._single_result = not isinstance(targets, (list, tuple)) |
|
|||
90 | else: |
|
|||
91 | self._single_result = False |
|
|||
92 |
|
92 | |||
93 | def __repr__(self): |
|
93 | def __repr__(self): | |
94 | if self._ready: |
|
94 | if self._ready: |
@@ -549,8 +549,8 b' class DirectView(View):' | |||||
549 | block = self.block if block is None else block |
|
549 | block = self.block if block is None else block | |
550 | track = self.track if track is None else track |
|
550 | track = self.track if track is None else track | |
551 | targets = self.targets if targets is None else targets |
|
551 | targets = self.targets if targets is None else targets | |
552 |
|
552 | |||
553 |
_idents = self.client._build_targets(targets) |
|
553 | _idents, _targets = self.client._build_targets(targets) | |
554 | msg_ids = [] |
|
554 | msg_ids = [] | |
555 | trackers = [] |
|
555 | trackers = [] | |
556 | for ident in _idents: |
|
556 | for ident in _idents: | |
@@ -559,8 +559,10 b' class DirectView(View):' | |||||
559 | if track: |
|
559 | if track: | |
560 | trackers.append(msg['tracker']) |
|
560 | trackers.append(msg['tracker']) | |
561 | msg_ids.append(msg['header']['msg_id']) |
|
561 | msg_ids.append(msg['header']['msg_id']) | |
|
562 | if isinstance(targets, int): | |||
|
563 | msg_ids = msg_ids[0] | |||
562 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
564 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |
563 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker) |
|
565 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker) | |
564 | if block: |
|
566 | if block: | |
565 | try: |
|
567 | try: | |
566 | return ar.get() |
|
568 | return ar.get() | |
@@ -631,13 +633,15 b' class DirectView(View):' | |||||
631 | block = self.block if block is None else block |
|
633 | block = self.block if block is None else block | |
632 | targets = self.targets if targets is None else targets |
|
634 | targets = self.targets if targets is None else targets | |
633 |
|
635 | |||
634 |
_idents = self.client._build_targets(targets) |
|
636 | _idents, _targets = self.client._build_targets(targets) | |
635 | msg_ids = [] |
|
637 | msg_ids = [] | |
636 | trackers = [] |
|
638 | trackers = [] | |
637 | for ident in _idents: |
|
639 | for ident in _idents: | |
638 | msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident) |
|
640 | msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident) | |
639 | msg_ids.append(msg['header']['msg_id']) |
|
641 | msg_ids.append(msg['header']['msg_id']) | |
640 | ar = AsyncResult(self.client, msg_ids, fname='execute', targets=targets) |
|
642 | if isinstance(targets, int): | |
|
643 | msg_ids = msg_ids[0] | |||
|
644 | ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets) | |||
641 | if block: |
|
645 | if block: | |
642 | try: |
|
646 | try: | |
643 | ar.get() |
|
647 | ar.get() |
@@ -308,4 +308,18 b' class AsyncResultTest(ClusterTestCase):' | |||||
308 | ar.get(5) |
|
308 | ar.get(5) | |
309 | nt.assert_in(4, found) |
|
309 | nt.assert_in(4, found) | |
310 | self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found) |
|
310 | self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found) | |
|
311 | ||||
|
312 | def test_not_single_result(self): | |||
|
313 | save_build = self.client._build_targets | |||
|
314 | def single_engine(*a, **kw): | |||
|
315 | idents, targets = save_build(*a, **kw) | |||
|
316 | return idents[:1], targets[:1] | |||
|
317 | ids = single_engine('all')[1] | |||
|
318 | self.client._build_targets = single_engine | |||
|
319 | for targets in ('all', None, ids): | |||
|
320 | dv = self.client.direct_view(targets=targets) | |||
|
321 | ar = dv.apply_async(lambda : 5) | |||
|
322 | self.assertEqual(ar.get(10), [5]) | |||
|
323 | self.client._build_targets = save_build | |||
|
324 | ||||
311 |
|
325 |
General Comments 0
You need to be logged in to leave comments.
Login now