From 11c0dbe721a92aec815df7853f72b63c6cf52201 2011-04-08 00:38:21 From: MinRK Date: 2011-04-08 00:38:21 Subject: [PATCH] fix/test pushed function globals --- diff --git a/IPython/utils/pickleutil.py b/IPython/utils/pickleutil.py index 3314e9e..fcd7dc5 100644 --- a/IPython/utils/pickleutil.py +++ b/IPython/utils/pickleutil.py @@ -100,9 +100,9 @@ def uncan(obj, g=None): elif isinstance(obj, CannedObject): return obj.getObject(g) elif isinstance(obj,dict): - return uncanDict(obj) + return uncanDict(obj, g) elif isinstance(obj, (list,tuple)): - return uncanSequence(obj) + return uncanSequence(obj, g) else: return obj diff --git a/IPython/zmq/parallel/error.py b/IPython/zmq/parallel/error.py index 0cf4058..fea4057 100644 --- a/IPython/zmq/parallel/error.py +++ b/IPython/zmq/parallel/error.py @@ -201,7 +201,7 @@ class TaskRejectError(KernelError): """ -class CompositeError(KernelError): +class CompositeError(RemoteError): """Error for representing possibly multiple errors on engines""" def __init__(self, message, elist): Exception.__init__(self, *(message, elist)) @@ -215,7 +215,7 @@ class CompositeError(KernelError): if not ei: return '[Engine Exception]' else: - return '[%i:%s]: ' % (ei['engineid'], ei['method']) + return '[%s:%s]: ' % (ei['engineid'], ei['method']) def _get_traceback(self, ev): try: @@ -256,10 +256,7 @@ class CompositeError(KernelError): except: raise IndexError("an exception with index %i does not exist"%excid) else: - try: - raise RemoteError(en, ev, etb, ei) - except: - et,ev,tb = sys.exc_info() + raise RemoteError(en, ev, etb, ei) def collect_exceptions(rdict_or_list, method='unspecified'): @@ -290,6 +287,6 @@ def collect_exceptions(rdict_or_list, method='unspecified'): # instance (e in this case) try: raise CompositeError(msg, elist) - except CompositeError, e: + except CompositeError as e: raise e diff --git a/IPython/zmq/parallel/tests/clienttest.py b/IPython/zmq/parallel/tests/clienttest.py index 9f2684a..bbd18a3 100644 --- a/IPython/zmq/parallel/tests/clienttest.py +++ b/IPython/zmq/parallel/tests/clienttest.py @@ -8,9 +8,10 @@ from zmq.tests import BaseZMQTestCase from IPython.external.decorator import decorator +from IPython.zmq.parallel import error +from IPython.zmq.parallel.client import Client from IPython.zmq.parallel.ipcluster import launch_process from IPython.zmq.parallel.entry_point import select_random_ports -from IPython.zmq.parallel.client import Client from IPython.zmq.parallel.tests import processes,add_engine # simple tasks for use in apply tests @@ -70,6 +71,16 @@ class ClusterTestCase(BaseZMQTestCase): self.sockets.append(getattr(c, name)) return c + def assertRaisesRemote(self, etype, f, *args, **kwargs): + try: + f(*args, **kwargs) + except error.CompositeError as e: + e.raise_exception() + except error.RemoteError as e: + self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__)) + else: + self.fail("should have raised a RemoteError") + def setUp(self): BaseZMQTestCase.setUp(self) self.client = self.connect_client() diff --git a/IPython/zmq/parallel/tests/test_client.py b/IPython/zmq/parallel/tests/test_client.py index a016ff2..7ea30c4 100644 --- a/IPython/zmq/parallel/tests/test_client.py +++ b/IPython/zmq/parallel/tests/test_client.py @@ -44,10 +44,13 @@ class TestClient(ClusterTestCase): v = self.client[:-3] self.assert_(isinstance(v, DirectView)) self.assertEquals(v.targets, targets[:-3]) + v = self.client[-1] + self.assert_(isinstance(v, DirectView)) + self.assertEquals(v.targets, targets[-1]) nt.assert_raises(TypeError, lambda : self.client[None]) def test_view_cache(self): - """test blocking and non-blocking behavior""" + """test that multiple view requests return the same object""" v = self.client[:2] v2 =self.client[:2] self.assertTrue(v is v2) @@ -65,6 +68,7 @@ class TestClient(ClusterTestCase): # self.client.push() def test_push_pull(self): + """test pushing and pulling""" data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) self.add_engines(4) push = self.client.push @@ -89,6 +93,7 @@ class TestClient(ClusterTestCase): self.assertEquals(r, nengines*[[10,20]]) def test_push_pull_function(self): + "test pushing and pulling functions" def testf(x): return 2.0*x @@ -112,6 +117,18 @@ class TestClient(ClusterTestCase): execute("def g(x): return x*x", targets=0) r = pull(('testf','g'),targets=0) self.assertEquals((r[0](10),r[1](10)), (testf(10), 100)) - + + def test_push_function_globals(self): + """test that pushed functions have access to globals""" + def geta(): + return a + self.add_engines(1) + v = self.client[-1] + v.block=True + v['f'] = geta + self.assertRaisesRemote(NameError, v.execute, 'b=f()') + v.execute('a=5') + v.execute('b=f()') + self.assertEquals(v['b'], 5) \ No newline at end of file