diff --git a/IPython/parallel/client/asyncresult.py b/IPython/parallel/client/asyncresult.py index a8111b4..c5a5e52 100644 --- a/IPython/parallel/client/asyncresult.py +++ b/IPython/parallel/client/asyncresult.py @@ -16,16 +16,10 @@ from IPython.external.decorator import decorator from IPython.parallel import error from IPython.utils.py3compat import string_types -#----------------------------------------------------------------------------- -# Functions -#----------------------------------------------------------------------------- def _raw_text(s): display_pretty(s, raw=True) -#----------------------------------------------------------------------------- -# Classes -#----------------------------------------------------------------------------- # global empty tracker that's always done: finished_tracker = MessageTracker() @@ -48,8 +42,11 @@ class AsyncResult(object): _targets = None _tracker = None _single_result = False + owner = False # if True, get() removes this result from the client's results/metadata caches - def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None): + def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None, + owner=False, + ): if isinstance(msg_ids, string_types): # always a list msg_ids = [msg_ids] @@ -64,6 +61,7 @@ class AsyncResult(object): self._fname=fname self._targets = targets self._tracker = tracker + self.owner = owner self._ready = False self._outputs_ready = False @@ -150,6 +148,12 @@ class AsyncResult(object): # cutoff infinite wait at 10s timeout = 10 self._wait_for_outputs(timeout) + + if self.owner: # only the owning AsyncResult evicts the cached result + self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids] + for mid in self.msg_ids: + self._client.results.pop(mid) +  def successful(self): @@ -691,5 +695,10 @@ class AsyncHubResult(AsyncResult): self._success = True finally: self._metadata = [self._client.metadata[mid] for mid in self.msg_ids] + if self.owner: # 
metadata was captured above, so it is safe to evict now + for mid in self.msg_ids: + self._client.metadata.pop(mid) + self._client.results.pop(mid) + __all__ = 
['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] diff --git a/IPython/parallel/client/client.py b/IPython/parallel/client/client.py index fe6203d..87ad3aa 100644 --- a/IPython/parallel/client/client.py +++ b/IPython/parallel/client/client.py @@ -1358,7 +1358,7 @@ class Client(HasTraits): #-------------------------------------------------------------------------- @spin_first - def get_result(self, indices_or_msg_ids=None, block=None): + def get_result(self, indices_or_msg_ids=None, block=None, owner=True): """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object. If the client already has the results, no request to the Hub will be made. @@ -1384,6 +1384,11 @@ class Client(HasTraits): block : bool Whether to wait for the result to be done + owner : bool [default: True] + Whether this AsyncResult should own the result. + If so, calling `ar.get()` will remove data from the + client's result and metadata cache. + There should only be one owner of any given msg_id. Returns ------- @@ -1421,9 +1426,9 @@ class Client(HasTraits): theids = theids[0] if remote_ids: - ar = AsyncHubResult(self, msg_ids=theids) + ar = AsyncHubResult(self, msg_ids=theids, owner=owner) else: - ar = AsyncResult(self, msg_ids=theids) + ar = AsyncResult(self, msg_ids=theids, owner=owner) if block: ar.wait() @@ -1703,8 +1708,8 @@ class Client(HasTraits): if still_outstanding: raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding) for mid in msg_ids: - self.results.pop(mid) - self.metadata.pop(mid) + self.results.pop(mid, None) + self.metadata.pop(mid, None) @spin_first diff --git a/IPython/parallel/client/view.py b/IPython/parallel/client/view.py index bbd474a..56a2e52 100644 --- a/IPython/parallel/client/view.py +++ b/IPython/parallel/client/view.py @@ -1,20 +1,9 @@ -"""Views of remote engines. +"""Views of remote engines.""" -Authors: +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. 
-* Min RK -""" from __future__ import print_function -#----------------------------------------------------------------------------- -# Copyright (C) 2010-2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. -#----------------------------------------------------------------------------- - -#----------------------------------------------------------------------------- -# Imports -#----------------------------------------------------------------------------- import imp import sys @@ -315,7 +304,7 @@ class View(HasTraits): return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block) @spin_after - def get_result(self, indices_or_msg_ids=None): + def get_result(self, indices_or_msg_ids=None, block=None, owner=True): # block and owner are passed straight through to Client.get_result """return one or more results, specified by history index or msg_id. See :meth:`IPython.parallel.client.client.Client.get_result` for details. 
@@ -330,7 +319,7 @@ class View(HasTraits): for i,index in enumerate(indices_or_msg_ids): if isinstance(index, int): indices_or_msg_ids[i] = self.history[index] - return self.client.get_result(indices_or_msg_ids) + return self.client.get_result(indices_or_msg_ids, block=block, owner=owner) #------------------------------------------------------------------- # Map @@ -577,7 +566,9 @@ class DirectView(View): if isinstance(targets, int): msg_ids = msg_ids[0] tracker = None if track is False else zmq.MessageTracker(*trackers) - ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker) + ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, + tracker=tracker, owner=True, # the submitting AsyncResult owns the cached result + ) if block: try: return ar.get() @@ -656,7 +647,7 @@ class DirectView(View): msg_ids.append(msg['header']['msg_id']) if isinstance(targets, int): msg_ids = msg_ids[0] - ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets) + ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True) if block: try: ar.get() @@ -774,7 +765,9 @@ class DirectView(View): else: tracker = None - r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker) + r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, + tracker=tracker, owner=True, + ) if block: r.wait() else: @@ -1057,8 +1050,9 @@ class LoadBalancedView(View): metadata=metadata) tracker = None if track is False else msg['tracker'] - ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker) - + ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), + targets=None, tracker=tracker, owner=True, + ) if block: try: return ar.get() diff --git a/IPython/parallel/tests/test_asyncresult.py b/IPython/parallel/tests/test_asyncresult.py index f330a52..ddb458d 100644 --- 
a/IPython/parallel/tests/test_asyncresult.py +++ b/IPython/parallel/tests/test_asyncresult.py @@ 
-1,20 +1,7 @@ -"""Tests for asyncresult.py +"""Tests for asyncresult.py""" -Authors: - -* Min RK -""" - -#------------------------------------------------------------------------------- -# Copyright (C) 2011 The IPython Development Team -# -# Distributed under the terms of the BSD License. The full license is in -# the file COPYING, distributed as part of this software. -#------------------------------------------------------------------------------- - -#------------------------------------------------------------------------------- -# Imports -#------------------------------------------------------------------------------- +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. import time @@ -323,5 +310,33 @@ class AsyncResultTest(ClusterTestCase): ar = dv.apply_async(lambda : 5) self.assertEqual(ar.get(10), [5]) self.client._build_targets = save_build + + def test_owner_pop(self): + self.minimum_engines(1) + + view = self.client[-1] + ar = view.apply_async(lambda : 1) + ar.get() # ar is the owner, so this evicts the cached result + msg_id = ar.msg_ids[0] + self.assertNotIn(msg_id, self.client.results) + self.assertNotIn(msg_id, self.client.metadata) + + def test_non_owner(self): + self.minimum_engines(1) + + view = self.client[-1] + ar = view.apply_async(lambda : 1) + ar.owner = False # opt out of eviction so the cache survives get() + ar.get() + msg_id = ar.msg_ids[0] + self.assertIn(msg_id, self.client.results) + self.assertIn(msg_id, self.client.metadata) + ar2 = self.client.get_result(msg_id, owner=True) + self.assertIs(type(ar2), type(ar)) + self.assertTrue(ar2.owner) + self.assertEqual(ar.get(), ar2.get()) + ar2.get() + self.assertNotIn(msg_id, self.client.results) + self.assertNotIn(msg_id, self.client.metadata) diff --git a/IPython/parallel/tests/test_client.py b/IPython/parallel/tests/test_client.py index e19da91..02dd290 100644 --- a/IPython/parallel/tests/test_client.py +++ b/IPython/parallel/tests/test_client.py @@ -143,11 +143,12 @@ 
class TestClient(ClusterTestCase): ar = c[t].apply_async(wait, 1) # 
give the monitor time to notice the message time.sleep(.25) - ahr = self.client.get_result(ar.msg_ids[0]) - self.assertTrue(isinstance(ahr, AsyncHubResult)) + ahr = self.client.get_result(ar.msg_ids[0], owner=False) # non-owner: result stays cached, so ar2 below is served locally + self.assertIsInstance(ahr, AsyncHubResult) self.assertEqual(ahr.get(), ar.get()) ar2 = self.client.get_result(ar.msg_ids[0]) - self.assertFalse(isinstance(ar2, AsyncHubResult)) + self.assertNotIsInstance(ar2, AsyncHubResult) + self.assertEqual(ahr.get(), ar2.get()) c.close() def test_get_execute_result(self): @@ -162,11 +163,12 @@ class TestClient(ClusterTestCase): ar = c[t].execute("import time; time.sleep(1)", silent=False) # give the monitor time to notice the message time.sleep(.25) - ahr = self.client.get_result(ar.msg_ids[0]) - self.assertTrue(isinstance(ahr, AsyncHubResult)) + ahr = self.client.get_result(ar.msg_ids[0], owner=False) # non-owner: result stays cached, so ar2 below is served locally + self.assertIsInstance(ahr, AsyncHubResult) self.assertEqual(ahr.get().execute_result, ar.get().execute_result) ar2 = self.client.get_result(ar.msg_ids[0]) - self.assertFalse(isinstance(ar2, AsyncHubResult)) + self.assertNotIsInstance(ar2, AsyncHubResult) + self.assertEqual(ahr.get(), ar2.get()) c.close() def test_ids_list(self): @@ -450,6 +452,7 @@ class TestClient(ClusterTestCase): v = self.client[-1] ar = v.apply_async(lambda : 1) msg_id = ar.msg_ids[0] + ar.owner = False # keep the result in the client cache for the rest of this test ar.get() self._wait_for_idle() ar2 = v.apply_async(time.sleep, 1) diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py index 2bf499b..460f13e 100644 --- a/IPython/parallel/tests/test_view.py +++ b/IPython/parallel/tests/test_view.py @@ -142,11 +142,12 @@ class TestView(ClusterTestCase): ar = v.apply_async(wait, 1) # give the monitor time to notice the message time.sleep(.25) - ahr = v2.get_result(ar.msg_ids[0]) - 
self.assertTrue(isinstance(ahr, AsyncHubResult)) + ahr = v2.get_result(ar.msg_ids[0], owner=False) # non-owner: result stays cached, so ar2 below is served locally + self.assertIsInstance(ahr, AsyncHubResult) self.assertEqual(ahr.get(), ar.get()) ar2 = 
v2.get_result(ar.msg_ids[0]) - self.assertFalse(isinstance(ar2, AsyncHubResult)) + self.assertNotIsInstance(ar2, AsyncHubResult) + self.assertEqual(ahr.get(), ar2.get()) c.spin() c.close()