From 0ae639e3e56f93606cbe4907144234c141ecbdaf 2012-05-29 03:20:31
From: MinRK
Date: 2012-05-29 03:20:31
Subject: [PATCH] move some general parallel test utilities to clienttest

---
Review notes (placed between the `---` line and the first `diff --git`
line; this text is not part of the commit message or of the diff):

* clienttest.py gains `from StringIO import StringIO`; in this patch it is
  used by the new `capture_output` context manager. Together with the
  `print` / `print >>` statements below, this makes the patched code
  Python 2 only.
* `generate_output` prints twice to stdout and twice to stderr, calls
  `display()` twice (HTML, then Math), and returns `Math("42")`. It
  re-imports `sys` and imports the display helpers inside its own body,
  even though `sys` is already imported at module level.
  NOTE(review): presumably this keeps the function self-contained so it
  can be shipped to and run on engines -- confirm against the callers.
* `capture_output.__enter__` saves the current `sys.stdout`/`sys.stderr`,
  replaces them with fresh `StringIO` objects, and returns a `CapturedIO`
  wrapping those buffers. `__exit__` restores the saved streams and
  returns None, so an exception raised inside the `with` block still
  propagates (after the streams are restored). `CapturedIO.stdout` /
  `.stderr` call `getvalue()` on every access, so they reflect whatever
  has been written so far.
* `_wait_for` moves from `TestView` (test_view.py) to `ClusterTestCase`
  (clienttest.py). It gains a docstring; the body is otherwise unchanged.
  `TestView` subclasses `ClusterTestCase` (see the test_view.py hunk
  header), so it still inherits the method.
  - It polls `f()`, sleeping 0.1 s and calling `self.client.spin()` between
    polls, for up to `timeout` seconds.
  - On timeout it only prints a warning; it does not fail the test.

diff --git a/IPython/parallel/tests/clienttest.py b/IPython/parallel/tests/clienttest.py
index b902293..60017b6 100644
--- a/IPython/parallel/tests/clienttest.py
+++ b/IPython/parallel/tests/clienttest.py
@@ -15,6 +15,7 @@ Authors:
 import sys
 import tempfile
 import time
+from StringIO import StringIO
 
 from nose import SkipTest
 
@@ -59,6 +60,28 @@ def raiser(eclass):
     """raise an exception"""
     raise eclass()
 
+def generate_output():
+    """function for testing output
+
+    publishes two outputs of each type, and returns
+    a rich displayable object.
+    """
+
+    import sys
+    from IPython.core.display import display, HTML, Math
+
+    print "stdout"
+    print >> sys.stderr, "stderr"
+
+    display(HTML("HTML"))
+
+    print "stdout2"
+    print >> sys.stderr, "stderr2"
+
+    display(Math(r"\alpha=\beta"))
+
+    return Math("42")
+
 # test decorator for skipping tests when libraries are unavailable
 def skip_without(*names):
     """skip a test if some names are not importable"""
@@ -73,6 +96,41 @@ def skip_without(*names):
         return f(*args, **kwargs)
     return skip_without_names
 
+#-------------------------------------------------------------------------------
+# Classes
+#-------------------------------------------------------------------------------
+
+class CapturedIO(object):
+    """Simple object for containing captured stdout/err StringIO objects"""
+
+    def __init__(self, stdout, stderr):
+        self.stdout_io = stdout
+        self.stderr_io = stderr
+
+    @property
+    def stdout(self):
+        return self.stdout_io.getvalue()
+
+    @property
+    def stderr(self):
+        return self.stderr_io.getvalue()
+
+
+class capture_output(object):
+    """context manager for capturing stdout/err"""
+
+    def __enter__(self):
+        self.sys_stdout = sys.stdout
+        self.sys_stderr = sys.stderr
+        stdout = sys.stdout = StringIO()
+        stderr = sys.stderr = StringIO()
+        return CapturedIO(stdout, stderr)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        sys.stdout = self.sys_stdout
+        sys.stderr = self.sys_stderr
+
+
 class ClusterTestCase(BaseZMQTestCase):
 
     def add_engines(self, n=1, block=True):
@@ -117,6 +175,17 @@ class ClusterTestCase(BaseZMQTestCase):
             else:
                 self.fail("should have raised a RemoteError")
 
+    def _wait_for(self, f, timeout=10):
+        """wait for a condition"""
+        tic = time.time()
+        while time.time() <= tic + timeout:
+            if f():
+                return
+            time.sleep(0.1)
+            self.client.spin()
+        if not f():
+            print "Warning: Awaited condition never arrived"
+
     def setUp(self):
         BaseZMQTestCase.setUp(self)
         self.client = self.connect_client()
diff --git a/IPython/parallel/tests/test_view.py b/IPython/parallel/tests/test_view.py
index f48b162..85f1f94 100644
--- a/IPython/parallel/tests/test_view.py
+++ b/IPython/parallel/tests/test_view.py
@@ -590,16 +590,6 @@ class TestView(ClusterTestCase, ParametricTestCase):
 
     # begin execute tests
 
-    def _wait_for(self, f, timeout=10):
-        tic = time.time()
-        while time.time() <= tic + timeout:
-            if f():
-                return
-            time.sleep(0.1)
-            self.client.spin()
-        if not f():
-            print "Warning: Awaited condition never arrived"
-
     def test_execute_reply(self):
         e0 = self.client[self.client.ids[0]]