diff --git a/IPython/parallel/client/asyncresult.py b/IPython/parallel/client/asyncresult.py index 338cbc6..5ddf697 100644 --- a/IPython/parallel/client/asyncresult.py +++ b/IPython/parallel/client/asyncresult.py @@ -82,6 +82,7 @@ class AsyncResult(object): self._targets = targets self._tracker = tracker self._ready = False + self._outputs_ready = False self._success = None self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ] if len(msg_ids) == 1: @@ -134,6 +135,9 @@ class AsyncResult(object): """Return whether the call has completed.""" if not self._ready: self.wait(0) + elif not self._outputs_ready: + self._wait_for_outputs(0) + return self._ready def wait(self, timeout=-1): @@ -142,6 +146,7 @@ class AsyncResult(object): This method always returns None. """ if self._ready: + self._wait_for_outputs(timeout) return self._ready = self._client.wait(self.msg_ids, timeout) if self._ready: @@ -161,8 +166,10 @@ class AsyncResult(object): else: self._success = True finally: - - self._wait_for_outputs(10) + if timeout is None or timeout < 0: + # cutoff infinite wait at 10s + timeout = 10 + self._wait_for_outputs(timeout) def successful(self): @@ -251,6 +258,7 @@ class AsyncResult(object): return error.collect_exceptions(self._result[key], self._fname) elif isinstance(key, basestring): # metadata proxy *does not* require that results are done + self.wait(0) values = [ md[key] for md in self._metadata ] if self._single_result: return values[0] @@ -377,11 +385,13 @@ class AsyncResult(object): """ return self.timedelta(self.submitted, self.received) - def wait_interactive(self, interval=1., timeout=None): + def wait_interactive(self, interval=1., timeout=-1): """interactive wait, printing progress at regular intervals""" + if timeout is None: + timeout = -1 N = len(self) tic = time.time() - while not self.ready() and (timeout is None or time.time() - tic <= timeout): + while not self.ready() and (timeout < 0 or time.time() - tic <= timeout): self.wait(interval) clear_output() print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="") @@ -433,13 +443,21 @@ class AsyncResult(object): def _wait_for_outputs(self, timeout=-1): """wait for the 'status=idle' message that indicates we have all outputs """ - if not self._success: + if self._outputs_ready or not self._success: # don't wait on errors return + + # cast None to -1 for infinite timeout + if timeout is None: + timeout = -1 + tic = time.time() - while not all(md['outputs_ready'] for md in self._metadata): + self._client._flush_iopub(self._client._iopub_socket) + self._outputs_ready = all(md['outputs_ready'] for md in self._metadata) + while not self._outputs_ready: time.sleep(0.01) self._client._flush_iopub(self._client._iopub_socket) + self._outputs_ready = all(md['outputs_ready'] for md in self._metadata) if timeout >= 0 and time.time() > tic + timeout: break @@ -643,9 +661,9 @@ class AsyncHubResult(AsyncResult): so use `AsyncHubResult.wait()` sparingly. """ - def _wait_for_outputs(self, timeout=None): + def _wait_for_outputs(self, timeout=-1): """no-op, because HubResults are never incomplete""" - return + self._outputs_ready = True def wait(self, timeout=-1): """wait for result to complete.""" diff --git a/IPython/parallel/tests/test_asyncresult.py b/IPython/parallel/tests/test_asyncresult.py index 9ffeeef..a09a123 100644 --- a/IPython/parallel/tests/test_asyncresult.py +++ b/IPython/parallel/tests/test_asyncresult.py @@ -18,6 +18,8 @@ Authors: import time +import nose.tools as nt + from IPython.utils.io import capture_output from IPython.parallel.error import TimeoutError @@ -263,5 +265,30 @@ class AsyncResultTest(ClusterTestCase): ar.display_outputs('engine') self.assertEqual(io.stderr, '') self.assertEqual(io.stdout, '') + + def test_await_data(self): + """asking for ar.data flushes outputs""" + self.minimum_engines(1) + v = self.client[-1] + ar = v.execute('\n'.join([ + "import time", + "from IPython.zmq.datapub import publish_data", + "for i in range(5):", + " publish_data(dict(i=i))", + " time.sleep(0.1)", + ]), block=False) + found = set() + tic = time.time() + # timeout after 10s + while time.time() <= tic + 10: + if ar.data: + found.add(ar.data['i']) + if ar.data['i'] == 4: + break + time.sleep(0.05) + + ar.get(5) + nt.assert_in(4, found) + self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)