From ccbe1c2bf6c8eb6c7257de2d120e6b3139bf06b5 2013-07-16 00:06:14 From: Min RK Date: 2013-07-16 00:06:14 Subject: [PATCH] Merge pull request #3649 from minrk/get_dict_single fix AsyncResult.get_dict for single result and add tests for single-result and invalid input (multiple results on one engine). closes #3646 --- diff --git a/IPython/parallel/client/asyncresult.py b/IPython/parallel/client/asyncresult.py index 52b3d72..0b86b7f 100644 --- a/IPython/parallel/client/asyncresult.py +++ b/IPython/parallel/client/asyncresult.py @@ -191,14 +191,21 @@ class AsyncResult(object): """ results = self.get(timeout) + if self._single_result: + results = [results] engine_ids = [ md['engine_id'] for md in self._metadata ] - bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k)) - maxcount = bycount.count(bycount[-1]) - if maxcount > 1: - raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%( - maxcount, bycount[-1])) + + + rdict = {} + for engine_id, result in zip(engine_ids, results): + if engine_id in rdict: + raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % ( + engine_ids.count(engine_id), engine_id) + ) + else: + rdict[engine_id] = result - return dict(zip(engine_ids,results)) + return rdict @property def result(self): diff --git a/IPython/parallel/tests/test_asyncresult.py b/IPython/parallel/tests/test_asyncresult.py index 0c112a9..53d4318 100644 --- a/IPython/parallel/tests/test_asyncresult.py +++ b/IPython/parallel/tests/test_asyncresult.py @@ -35,6 +35,9 @@ def wait(n): time.sleep(n) return n +def echo(x): + return x + class AsyncResultTest(ClusterTestCase): def test_single_result_view(self): @@ -77,6 +80,20 @@ class AsyncResultTest(ClusterTestCase): for eid,r in d.iteritems(): self.assertEqual(r, 5) + def test_get_dict_single(self): + view = self.client[-1] + for v in (range(5), 5, ('abc', 'def'), 'string'): + ar = view.apply_async(echo, v) + self.assertEqual(ar.get(), v) + d = ar.get_dict() + self.assertEqual(d, {view.targets : v}) + + def test_get_dict_bad(self): + ar = self.client[:].apply_async(lambda : 5) + ar2 = self.client[:].apply_async(lambda : 5) + ar = self.client.get_result(ar.msg_ids + ar2.msg_ids) + self.assertRaises(ValueError, ar.get_dict) + def test_list_amr(self): ar = self.client.load_balanced_view().map_async(wait, [0.1]*5) rlist = list(ar)