diff --git a/IPython/parallel/tests/__init__.py b/IPython/parallel/tests/__init__.py index 83aba19..cb7d960 100644 --- a/IPython/parallel/tests/__init__.py +++ b/IPython/parallel/tests/__init__.py @@ -68,9 +68,18 @@ def setup(): time.sleep(0.1) add_engines(1) -def add_engines(n=1, profile='iptest'): +def add_engines(n=1, profile='iptest', total=False): + """add a number of engines to a given profile. + + If total is True, then already running engines are counted, and only + the additional engines necessary (if any) are started. + """ rc = Client(profile=profile) base = len(rc) + + if total: + n = max(n - base, 0) + eps = [] for i in range(n): ep = TestProcessLauncher() diff --git a/IPython/parallel/tests/clienttest.py b/IPython/parallel/tests/clienttest.py index 07635c2..b902293 100644 --- a/IPython/parallel/tests/clienttest.py +++ b/IPython/parallel/tests/clienttest.py @@ -80,6 +80,13 @@ class ClusterTestCase(BaseZMQTestCase): self.engines.extend(add_engines(n)) if block: self.wait_on_engines() + + def minimum_engines(self, n=1, block=True): + """add engines until there are at least n connected""" + self.engines.extend(add_engines(n, total=True)) + if block: + self.wait_on_engines() + def wait_on_engines(self, timeout=5): """wait for our engines to connect.""" diff --git a/IPython/parallel/tests/test_asyncresult.py b/IPython/parallel/tests/test_asyncresult.py index 30dae4b..d9f07a1 100644 --- a/IPython/parallel/tests/test_asyncresult.py +++ b/IPython/parallel/tests/test_asyncresult.py @@ -23,7 +23,7 @@ from IPython.parallel.tests import add_engines from .clienttest import ClusterTestCase def setup(): - add_engines(2) + add_engines(2, total=True) def wait(n): import time diff --git a/IPython/parallel/tests/test_client.py b/IPython/parallel/tests/test_client.py index db58524..bd6ab3a 100644 --- a/IPython/parallel/tests/test_client.py +++ b/IPython/parallel/tests/test_client.py @@ -32,18 +32,18 @@ from IPython.parallel import LoadBalancedView, DirectView from clienttest import ClusterTestCase, segfault, wait, add_engines def setup(): - add_engines(4) + add_engines(4, total=True) class TestClient(ClusterTestCase): def test_ids(self): n = len(self.client.ids) - self.add_engines(3) - self.assertEquals(len(self.client.ids), n+3) + self.add_engines(2) + self.assertEquals(len(self.client.ids), n+2) def test_view_indexing(self): """test index access for views""" - self.add_engines(2) + self.minimum_engines(4) targets = self.client._build_targets('all')[-1] v = self.client[:] self.assertEquals(v.targets, targets) @@ -98,7 +98,7 @@ class TestClient(ClusterTestCase): ref = [ double(x) for x in seq ] # add some engines, which should be used - self.add_engines(2) + self.add_engines(1) n1 = len(self.client.ids) # simple apply @@ -131,7 +131,7 @@ class TestClient(ClusterTestCase): def test_clear(self): """test clear behavior""" - # self.add_engines(2) + self.minimum_engines(2) v = self.client[:] v.block=True v.push(dict(a=5)) @@ -142,13 +142,11 @@ class TestClient(ClusterTestCase): self.assertRaisesRemote(NameError, self.client[id0].get, 'a') self.client.clear(block=True) for i in self.client.ids: - # print i self.assertRaisesRemote(NameError, self.client[i].get, 'a') def test_get_result(self): """test getting results from the Hub.""" c = clientmod.Client(profile='iptest') - # self.add_engines(1) t = c.ids[-1] ar = c[t].apply_async(wait, 1) # give the monitor time to notice the message @@ -162,7 +160,6 @@ class TestClient(ClusterTestCase): def test_ids_list(self): """test client.ids""" - # self.add_engines(2) ids = self.client.ids self.assertEquals(ids, self.client._ids) self.assertFalse(ids is self.client._ids) @@ -170,7 +167,6 @@ class TestClient(ClusterTestCase): self.assertNotEquals(ids, self.client._ids) def test_queue_status(self): - # self.addEngine(4) ids = self.client.ids id0 = ids[0] qs = self.client.queue_status(targets=id0) @@ -187,7 +183,6 @@ class TestClient(ClusterTestCase): self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks']) def test_shutdown(self): - # self.addEngine(4) ids = self.client.ids id0 = ids[0] self.client.shutdown(id0, block=True) diff --git a/IPython/parallel/tests/test_dependency.py b/IPython/parallel/tests/test_dependency.py index 8efdf48..1fb1e50 100644 --- a/IPython/parallel/tests/test_dependency.py +++ b/IPython/parallel/tests/test_dependency.py @@ -30,7 +30,7 @@ from IPython.parallel.tests import add_engines from .clienttest import ClusterTestCase def setup(): - add_engines(1) + add_engines(1, total=True) @pmod.require('time') def wait(n): diff --git a/IPython/parallel/tests/test_lbview.py b/IPython/parallel/tests/test_lbview.py index 00756ef..f1774ac 100644 --- a/IPython/parallel/tests/test_lbview.py +++ b/IPython/parallel/tests/test_lbview.py @@ -30,7 +30,7 @@ from IPython.parallel.tests import add_engines from .clienttest import ClusterTestCase, crash, wait, skip_without def setup(): - add_engines(3) + add_engines(3, total=True) class TestLoadBalancedView(ClusterTestCase): @@ -120,7 +120,6 @@ class TestLoadBalancedView(ClusterTestCase): self.assertRaises(error.TaskAborted, ar3.get) def test_retries(self): - add_engines(3) view = self.view view.timeout = 1 # prevent hang if this doesn't behave def fail(): @@ -138,8 +137,7 @@ class TestLoadBalancedView(ClusterTestCase): self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1) def test_impossible_dependency(self): - if len(self.client) < 2: - add_engines(2) + self.minimum_engines(2) view = self.client.load_balanced_view() ar1 = view.apply_async(lambda : 1) ar1.get() diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py index c2908bb..3f1d335 100644 --- a/IPython/parallel/tests/test_view.py +++ b/IPython/parallel/tests/test_view.py @@ -37,7 +37,7 @@ from IPython.parallel.tests import add_engines from .clienttest import ClusterTestCase, crash, wait, skip_without def setup(): - add_engines(3) + add_engines(3, total=True) class TestView(ClusterTestCase): @@ -296,7 +296,7 @@ class TestView(ClusterTestCase): def test_abort_all(self): """view.abort() aborts all outstanding tasks""" view = self.client[-1] - ars = [ view.apply_async(time.sleep, 1) for i in range(10) ] + ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ] view.abort() view.wait(timeout=5) for ar in ars[5:]: