From 4f574e163c57914da5f32b3b026a1feb54bb538c 2011-04-08 00:38:21 From: MinRK Date: 2011-04-08 00:38:21 Subject: [PATCH] testing fixes --- diff --git a/IPython/zmq/parallel/asyncresult.py b/IPython/zmq/parallel/asyncresult.py index 0cfb16c..67daa14 100644 --- a/IPython/zmq/parallel/asyncresult.py +++ b/IPython/zmq/parallel/asyncresult.py @@ -35,7 +35,7 @@ class AsyncResult(object): msg_ids = None - def __init__(self, client, msg_ids, fname=''): + def __init__(self, client, msg_ids, fname='unknown'): self._client = client if isinstance(msg_ids, basestring): msg_ids = [msg_ids] @@ -265,7 +265,7 @@ class AsyncHubResult(AsyncResult): else: rdict = self._client.result_status(remote_ids, status_only=False) pending = rdict['pending'] - while pending and time.time() < start+timeout: + while pending and (timeout < 0 or time.time() < start+timeout): rdict = self._client.result_status(remote_ids, status_only=False) pending = rdict['pending'] if pending: diff --git a/IPython/zmq/parallel/client.py b/IPython/zmq/parallel/client.py index 78734b9..e02806d 100644 --- a/IPython/zmq/parallel/client.py +++ b/IPython/zmq/parallel/client.py @@ -360,16 +360,17 @@ class Client(HasTraits): if cluster_dir is not None: try: self._cd = ClusterDir.find_cluster_dir(cluster_dir) + return except ClusterDirError: pass elif profile is not None: try: self._cd = ClusterDir.find_cluster_dir_by_profile( ipython_dir, profile) + return except ClusterDirError: pass - else: - self._cd = None + self._cd = None @property def ids(self): @@ -489,9 +490,9 @@ class Client(HasTraits): """unwrap exception, and remap engineid to int.""" e = ss.unwrap_exception(content) if e.engine_info: - e_uuid = e.engine_info['engineid'] + e_uuid = e.engine_info['engine_uuid'] eid = self._engines[e_uuid] - e.engine_info['engineid'] = eid + e.engine_info['engine_id'] = eid return e def _register_engine(self, msg): @@ -1338,11 +1339,11 @@ class Client(HasTraits): be lists of msg_ids that are incomplete or complete. If `status_only` is False, then completed results will be keyed by their `msg_id`. """ - if not isinstance(indices_or_msg_ids, (list,tuple)): - indices_or_msg_ids = [indices_or_msg_ids] + if not isinstance(msg_ids, (list,tuple)): + indices_or_msg_ids = [msg_ids] theids = [] - for msg_id in indices_or_msg_ids: + for msg_id in msg_ids: if isinstance(msg_id, int): msg_id = self.history[msg_id] if not isinstance(msg_id, basestring): diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py index fea4057..cf48621 100644 --- a/IPython/zmq/parallel/error.py +++ b/IPython/zmq/parallel/error.py @@ -175,7 +175,7 @@ class RemoteError(KernelError): self.args=(ename, evalue) def __repr__(self): - engineid = self.engine_info.get('engineid', ' ') + engineid = self.engine_info.get('engine_id', ' ') return ""%(engineid, self.ename, self.evalue) def __str__(self): diff --git a/IPython/zmq/parallel/hub.py b/IPython/zmq/parallel/hub.py index 1c5b9ec..a992817 100755 --- a/IPython/zmq/parallel/hub.py +++ b/IPython/zmq/parallel/hub.py @@ -702,7 +702,7 @@ class Hub(LoggingFactory): self.log.error("task::invalid task tracking message", exc_info=True) return content = msg['content'] - print (content) + # print (content) msg_id = content['msg_id'] engine_uuid = content['engine_id'] eid = self.by_ident[engine_uuid] @@ -728,7 +728,7 @@ class Hub(LoggingFactory): def save_iopub_message(self, topics, msg): """save an iopub message into the db""" - print (topics) + # print (topics) try: msg = self.session.unpack_message(msg, content=True) except: diff --git a/IPython/zmq/parallel/remotefunction.py b/IPython/zmq/parallel/remotefunction.py index 8169bcb..2cdd716 100644 --- a/IPython/zmq/parallel/remotefunction.py +++ b/IPython/zmq/parallel/remotefunction.py @@ -12,6 +12,8 @@ import warnings +from IPython.testing import decorators as testdec + import map as Map from asyncresult import AsyncMapResult @@ -19,26 +21,32 @@ from asyncresult import AsyncMapResult # Decorators #----------------------------------------------------------------------------- +@testdec.skip_doctest def remote(client, bound=True, block=None, targets=None, balanced=None): """Turn a function into a remote function. This method can be used for map: - >>> @remote(client,block=True) - def func(a) + In [1]: @remote(client,block=True) + ...: def func(a): + ...: pass """ + def remote_function(f): return RemoteFunction(client, f, bound, block, targets, balanced) return remote_function +@testdec.skip_doctest def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None): """Turn a function into a parallel remote function. This method can be used for map: - >>> @parallel(client,block=True) - def func(a) + In [1]: @parallel(client,block=True) + ...: def func(a): + ...: pass """ + def parallel_function(f): return ParallelFunction(client, f, dist, bound, block, targets, balanced) return parallel_function diff --git a/IPython/zmq/parallel/streamkernel.py b/IPython/zmq/parallel/streamkernel.py index 81d9f7a..12458c5 100755 --- a/IPython/zmq/parallel/streamkernel.py +++ b/IPython/zmq/parallel/streamkernel.py @@ -104,7 +104,7 @@ class Kernel(SessionFactory): self._initial_exec_lines() def _wrap_exception(self, method=None): - e_info = dict(engineid=self.ident, method=method) + e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method) content=wrap_exception(e_info) return content diff --git a/IPython/zmq/parallel/tests/__init__.py b/IPython/zmq/parallel/tests/__init__.py index 03165bc..27f491b 100644 --- a/IPython/zmq/parallel/tests/__init__.py +++ b/IPython/zmq/parallel/tests/__init__.py @@ -29,7 +29,6 @@ def teardown(): p = processes.pop() if p.poll() is None: try: - print 'terminating' p.terminate() except Exception, e: print e diff --git a/IPython/zmq/parallel/tests/clienttest.py b/IPython/zmq/parallel/tests/clienttest.py index bbd18a3..ce55404 100644 --- a/IPython/zmq/parallel/tests/clienttest.py +++ b/IPython/zmq/parallel/tests/clienttest.py @@ -17,7 +17,7 @@ from IPython.zmq.parallel.tests import processes,add_engine # simple tasks for use in apply tests def segfault(): - """""" + """this will segfault""" import ctypes ctypes.memset(-1,0,1) @@ -73,9 +73,10 @@ class ClusterTestCase(BaseZMQTestCase): def assertRaisesRemote(self, etype, f, *args, **kwargs): try: - f(*args, **kwargs) - except error.CompositeError as e: - e.raise_exception() + try: + f(*args, **kwargs) + except error.CompositeError as e: + e.raise_exception() except error.RemoteError as e: self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__)) else: @@ -87,10 +88,11 @@ class ClusterTestCase(BaseZMQTestCase): self.base_engine_count=len(self.client.ids) self.engines=[] - def tearDown(self): - [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] - # while len(self.client.ids) > self.base_engine_count: - # time.sleep(.1) - del self.engines - BaseZMQTestCase.tearDown(self) + # def tearDown(self): + # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ] + # [ e.wait() for e in self.engines ] + # while len(self.client.ids) > self.base_engine_count: + # time.sleep(.1) + # del self.engines + # BaseZMQTestCase.tearDown(self) \ No newline at end of file diff --git a/IPython/zmq/parallel/tests/test_client.py b/IPython/zmq/parallel/tests/test_client.py index 7ea30c4..21ff03b 100644 --- a/IPython/zmq/parallel/tests/test_client.py +++ b/IPython/zmq/parallel/tests/test_client.py @@ -2,28 +2,35 @@ import time import nose.tools as nt -from IPython.zmq.parallel.asyncresult import AsyncResult +from IPython.zmq.parallel import client as clientmod +from IPython.zmq.parallel import error +from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult from IPython.zmq.parallel.view import LoadBalancedView, DirectView -from clienttest import ClusterTestCase, segfault +from clienttest import ClusterTestCase, segfault, wait class TestClient(ClusterTestCase): def test_ids(self): - self.assertEquals(len(self.client.ids), 1) + n = len(self.client.ids) self.add_engines(3) - self.assertEquals(len(self.client.ids), 4) + self.assertEquals(len(self.client.ids), n+3) + self.assertTrue def test_segfault(self): + """test graceful handling of engine death""" self.add_engines(1) eid = self.client.ids[-1] - self.client[eid].apply(segfault) + ar = self.client.apply(segfault, block=False) + self.assertRaisesRemote(error.EngineError, ar.get) + eid = ar.engine_id while eid in self.client.ids: time.sleep(.01) self.client.spin() def test_view_indexing(self): - self.add_engines(4) + """test index access for views""" + self.add_engines(2) targets = self.client._build_targets('all')[-1] v = self.client[:] self.assertEquals(v.targets, targets) @@ -60,17 +67,30 @@ class TestClient(ClusterTestCase): def test_targets(self): """test various valid targets arguments""" - pass + build = self.client._build_targets + ids = self.client.ids + idents,targets = build(None) + self.assertEquals(ids, targets) def test_clear(self): """test clear behavior""" - # self.add_engines(4) - # self.client.push() + self.add_engines(2) + self.client.block=True + self.client.push(dict(a=5)) + self.client.pull('a') + id0 = self.client.ids[-1] + self.client.clear(targets=id0) + self.client.pull('a', targets=self.client.ids[:-1]) + self.assertRaisesRemote(NameError, self.client.pull, 'a') + self.client.clear() + for i in self.client.ids: + self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i) + def test_push_pull(self): """test pushing and pulling""" data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) - self.add_engines(4) + self.add_engines(2) push = self.client.push pull = self.client.pull self.client.block=True @@ -131,4 +151,15 @@ class TestClient(ClusterTestCase): v.execute('b=f()') self.assertEquals(v['b'], 5) + def test_get_result(self): + """test getting results from the Hub.""" + c = clientmod.Client(profile='iptest') + t = self.client.ids[-1] + ar = c.apply(wait, (1,), block=False, targets=t) + time.sleep(.25) + ahr = self.client.get_result(ar.msg_ids) + self.assertTrue(isinstance(ahr, AsyncHubResult)) + self.assertEquals(ahr.get(), ar.get()) + ar2 = self.client.get_result(ar.msg_ids) + self.assertFalse(isinstance(ar2, AsyncHubResult)) \ No newline at end of file diff --git a/IPython/zmq/parallel/view.py b/IPython/zmq/parallel/view.py index 9be98d2..d2082a7 100644 --- a/IPython/zmq/parallel/view.py +++ b/IPython/zmq/parallel/view.py @@ -10,6 +10,7 @@ # Imports #----------------------------------------------------------------------------- +from IPython.testing import decorators as testdec from IPython.utils.traitlets import HasTraits, Bool, List, Dict, Set, Int, Instance from IPython.external.decorator import decorator @@ -330,7 +331,7 @@ class View(HasTraits): block = self.block if block is None else block return parallel(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced) - +@testdec.skip_doctest class DirectView(View): """Direct Multiplexer View of one or more engines. @@ -413,7 +414,7 @@ class DirectView(View): return self.client.push(ns, targets=self._targets, block=self.block) push = update - + def get(self, key_s): """get object(s) by `key_s` from remote namespace will return one object if it is a key. @@ -430,26 +431,24 @@ class DirectView(View): block = block if block is not None else self.block return self.client.pull(key_s, block=block, targets=self._targets) - def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None): + def scatter(self, key, seq, dist='b', flatten=False, block=None): """ Partition a Python sequence and send the partitions to a set of engines. """ block = block if block is not None else self.block - targets = targets if targets is not None else self._targets return self.client.scatter(key, seq, dist=dist, flatten=flatten, - targets=targets, block=block) + targets=self._targets, block=block) @sync_results @save_ids - def gather(self, key, dist='b', targets=None, block=None): + def gather(self, key, dist='b', block=None): """ Gather a partitioned sequence on a set of engines as a single local seq. """ block = block if block is not None else self.block - targets = targets if targets is not None else self._targets - return self.client.gather(key, dist=dist, targets=targets, block=block) + return self.client.gather(key, dist=dist, targets=self._targets, block=block) def __getitem__(self, key): return self.get(key) @@ -496,7 +495,8 @@ class DirectView(View): print "You must first load the parallelmagic extension " \ "by doing '%load_ext parallelmagic'" - + +@testdec.skip_doctest class LoadBalancedView(View): """An load-balancing View that only executes via the Task scheduler.