##// END OF EJS Templates
add ownership to AsyncResult objects...
MinRK -
Show More
@@ -16,16 +16,10 b' from IPython.external.decorator import decorator'
16 16 from IPython.parallel import error
17 17 from IPython.utils.py3compat import string_types
18 18
19 #-----------------------------------------------------------------------------
20 # Functions
21 #-----------------------------------------------------------------------------
22 19
23 20 def _raw_text(s):
24 21 display_pretty(s, raw=True)
25 22
26 #-----------------------------------------------------------------------------
27 # Classes
28 #-----------------------------------------------------------------------------
29 23
30 24 # global empty tracker that's always done:
31 25 finished_tracker = MessageTracker()
@@ -48,8 +42,11 b' class AsyncResult(object):'
48 42 _targets = None
49 43 _tracker = None
50 44 _single_result = False
45 owner = False,
51 46
52 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 owner=False,
49 ):
53 50 if isinstance(msg_ids, string_types):
54 51 # always a list
55 52 msg_ids = [msg_ids]
@@ -64,6 +61,7 b' class AsyncResult(object):'
64 61 self._fname=fname
65 62 self._targets = targets
66 63 self._tracker = tracker
64 self.owner = owner
67 65
68 66 self._ready = False
69 67 self._outputs_ready = False
@@ -150,6 +148,12 b' class AsyncResult(object):'
150 148 # cutoff infinite wait at 10s
151 149 timeout = 10
152 150 self._wait_for_outputs(timeout)
151
152 if self.owner:
153
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
156
153 157
154 158
155 159 def successful(self):
@@ -691,5 +695,9 b' class AsyncHubResult(AsyncResult):'
691 695 self._success = True
692 696 finally:
693 697 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
698 if self.owner:
699 [self._client.metadata.pop(mid) for mid in self.msg_ids]
700 [self._client.results.pop(mid) for mid in self.msg_ids]
701
694 702
695 703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1358,7 +1358,7 b' class Client(HasTraits):'
1358 1358 #--------------------------------------------------------------------------
1359 1359
1360 1360 @spin_first
1361 def get_result(self, indices_or_msg_ids=None, block=None):
1361 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1362 1362 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1363 1363
1364 1364 If the client already has the results, no request to the Hub will be made.
@@ -1384,6 +1384,11 b' class Client(HasTraits):'
1384 1384
1385 1385 block : bool
1386 1386 Whether to wait for the result to be done
1387 owner : bool [default: True]
1388 Whether this AsyncResult should own the result.
1389 If so, calling `ar.get()` will remove data from the
1390 client's result and metadata cache.
1391 There should only be one owner of any given msg_id.
1387 1392
1388 1393 Returns
1389 1394 -------
@@ -1421,9 +1426,9 b' class Client(HasTraits):'
1421 1426 theids = theids[0]
1422 1427
1423 1428 if remote_ids:
1424 ar = AsyncHubResult(self, msg_ids=theids)
1429 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1425 1430 else:
1426 ar = AsyncResult(self, msg_ids=theids)
1431 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1427 1432
1428 1433 if block:
1429 1434 ar.wait()
@@ -1703,8 +1708,8 b' class Client(HasTraits):'
1703 1708 if still_outstanding:
1704 1709 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1705 1710 for mid in msg_ids:
1706 self.results.pop(mid)
1707 self.metadata.pop(mid)
1711 self.results.pop(mid, None)
1712 self.metadata.pop(mid, None)
1708 1713
1709 1714
1710 1715 @spin_first
@@ -1,20 +1,9 b''
1 """Views of remote engines.
1 """Views of remote engines."""
2 2
3 Authors:
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 5
5 * Min RK
6 """
7 6 from __future__ import print_function
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
18 7
19 8 import imp
20 9 import sys
@@ -315,7 +304,7 b' class View(HasTraits):'
315 304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
316 305
317 306 @spin_after
318 def get_result(self, indices_or_msg_ids=None):
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
319 308 """return one or more results, specified by history index or msg_id.
320 309
321 310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
@@ -330,7 +319,7 b' class View(HasTraits):'
330 319 for i,index in enumerate(indices_or_msg_ids):
331 320 if isinstance(index, int):
332 321 indices_or_msg_ids[i] = self.history[index]
333 return self.client.get_result(indices_or_msg_ids)
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
334 323
335 324 #-------------------------------------------------------------------
336 325 # Map
@@ -577,7 +566,9 b' class DirectView(View):'
577 566 if isinstance(targets, int):
578 567 msg_ids = msg_ids[0]
579 568 tracker = None if track is False else zmq.MessageTracker(*trackers)
580 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
571 )
581 572 if block:
582 573 try:
583 574 return ar.get()
@@ -656,7 +647,7 b' class DirectView(View):'
656 647 msg_ids.append(msg['header']['msg_id'])
657 648 if isinstance(targets, int):
658 649 msg_ids = msg_ids[0]
659 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
660 651 if block:
661 652 try:
662 653 ar.get()
@@ -774,7 +765,9 b' class DirectView(View):'
774 765 else:
775 766 tracker = None
776 767
777 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
770 )
778 771 if block:
779 772 r.wait()
780 773 else:
@@ -1057,8 +1050,9 b' class LoadBalancedView(View):'
1057 1050 metadata=metadata)
1058 1051 tracker = None if track is False else msg['tracker']
1059 1052
1060 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1061
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 targets=None, tracker=tracker, owner=True,
1055 )
1062 1056 if block:
1063 1057 try:
1064 1058 return ar.get()
@@ -1,20 +1,7 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py"""
2 2
3 Authors:
4
5 * Min RK
6 """
7
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
14
15 #-------------------------------------------------------------------------------
16 # Imports
17 #-------------------------------------------------------------------------------
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
18 5
19 6 import time
20 7
@@ -323,5 +310,33 b' class AsyncResultTest(ClusterTestCase):'
323 310 ar = dv.apply_async(lambda : 5)
324 311 self.assertEqual(ar.get(10), [5])
325 312 self.client._build_targets = save_build
313
314 def test_owner_pop(self):
315 self.minimum_engines(1)
316
317 view = self.client[-1]
318 ar = view.apply_async(lambda : 1)
319 ar.get()
320 msg_id = ar.msg_ids[0]
321 self.assertNotIn(msg_id, self.client.results)
322 self.assertNotIn(msg_id, self.client.metadata)
323
324 def test_non_owner(self):
325 self.minimum_engines(1)
326
327 view = self.client[-1]
328 ar = view.apply_async(lambda : 1)
329 ar.owner = False
330 ar.get()
331 msg_id = ar.msg_ids[0]
332 self.assertIn(msg_id, self.client.results)
333 self.assertIn(msg_id, self.client.metadata)
334 ar2 = self.client.get_result(msg_id, owner=True)
335 self.assertIs(type(ar2), type(ar))
336 self.assertTrue(ar2.owner)
337 self.assertEqual(ar.get(), ar2.get())
338 ar2.get()
339 self.assertNotIn(msg_id, self.client.results)
340 self.assertNotIn(msg_id, self.client.metadata)
326 341
327 342
@@ -143,11 +143,12 b' class TestClient(ClusterTestCase):'
143 143 ar = c[t].apply_async(wait, 1)
144 144 # give the monitor time to notice the message
145 145 time.sleep(.25)
146 ahr = self.client.get_result(ar.msg_ids[0])
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
146 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
147 self.assertIsInstance(ahr, AsyncHubResult)
148 148 self.assertEqual(ahr.get(), ar.get())
149 149 ar2 = self.client.get_result(ar.msg_ids[0])
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 self.assertNotIsInstance(ar2, AsyncHubResult)
151 self.assertEqual(ahr.get(), ar2.get())
151 152 c.close()
152 153
153 154 def test_get_execute_result(self):
@@ -162,11 +163,12 b' class TestClient(ClusterTestCase):'
162 163 ar = c[t].execute("import time; time.sleep(1)", silent=False)
163 164 # give the monitor time to notice the message
164 165 time.sleep(.25)
165 ahr = self.client.get_result(ar.msg_ids[0])
166 self.assertTrue(isinstance(ahr, AsyncHubResult))
166 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
167 self.assertIsInstance(ahr, AsyncHubResult)
167 168 self.assertEqual(ahr.get().execute_result, ar.get().execute_result)
168 169 ar2 = self.client.get_result(ar.msg_ids[0])
169 self.assertFalse(isinstance(ar2, AsyncHubResult))
170 self.assertNotIsInstance(ar2, AsyncHubResult)
171 self.assertEqual(ahr.get(), ar2.get())
170 172 c.close()
171 173
172 174 def test_ids_list(self):
@@ -450,6 +452,7 b' class TestClient(ClusterTestCase):'
450 452 v = self.client[-1]
451 453 ar = v.apply_async(lambda : 1)
452 454 msg_id = ar.msg_ids[0]
455 ar.owner = False
453 456 ar.get()
454 457 self._wait_for_idle()
455 458 ar2 = v.apply_async(time.sleep, 1)
@@ -142,11 +142,12 b' class TestView(ClusterTestCase):'
142 142 ar = v.apply_async(wait, 1)
143 143 # give the monitor time to notice the message
144 144 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids[0])
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
145 ahr = v2.get_result(ar.msg_ids[0], owner=False)
146 self.assertIsInstance(ahr, AsyncHubResult)
147 147 self.assertEqual(ahr.get(), ar.get())
148 148 ar2 = v2.get_result(ar.msg_ids[0])
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
149 self.assertNotIsInstance(ar2, AsyncHubResult)
150 self.assertEqual(ahr.get(), ar2.get())
150 151 c.spin()
151 152 c.close()
152 153
General Comments 0
You need to be logged in to leave comments. Login now