diff --git a/IPython/extensions/tests/test_autoreload.py b/IPython/extensions/tests/test_autoreload.py
index c0bb425..d4b4005 100644
--- a/IPython/extensions/tests/test_autoreload.py
+++ b/IPython/extensions/tests/test_autoreload.py
@@ -7,6 +7,7 @@ import time
 from StringIO import StringIO
 
 import nose.tools as nt
+import IPython.testing.tools as tt
 
 from IPython.extensions.autoreload import AutoreloadInterface
 from IPython.core.hooks import TryNext
@@ -197,19 +198,11 @@ class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
 a syntax error
 """)
 
-        old_stderr = sys.stderr
-        new_stderr = StringIO()
-        sys.stderr = new_stderr
-        try:
+        with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
             self.shell.run_code("pass") # trigger reload
+        with tt.AssertNotPrints('[autoreload of', channel='stderr'):
             self.shell.run_code("pass") # trigger another reload
-            check_module_contents()
-        finally:
-            sys.stderr = old_stderr
-
-        nt.assert_true(('[autoreload of %s failed:' % mod_name) in
-                       new_stderr.getvalue())
-        nt.assert_equal(new_stderr.getvalue().count('[autoreload of'), 1)
+        check_module_contents()
 
         #
         # Rewrite module (this time reload should succeed)
diff --git a/IPython/testing/tools.py b/IPython/testing/tools.py
index 5496699..36986ff 100644
--- a/IPython/testing/tools.py
+++ b/IPython/testing/tools.py
@@ -342,22 +342,23 @@ class AssertPrints(object):
 
     Examples
     --------
-    >>> with AssertPrints("abc"):
+    >>> with AssertPrints("abc", suppress=False):
     ...     print "abcd"
     ...     print "def"
     ...
     abcd
     def
     """
-    def __init__(self, s, channel='stdout'):
+    def __init__(self, s, channel='stdout', suppress=True):
         self.s = s
         self.channel = channel
+        self.suppress = suppress
 
     def __enter__(self):
         self.orig_stream = getattr(sys, self.channel)
         self.buffer = MyStringIO()
         self.tee = Tee(self.buffer, channel=self.channel)
-        setattr(sys, self.channel, self.tee)
+        setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
 
     def __exit__(self, etype, value, traceback):
         self.tee.flush()
@@ -365,6 +366,24 @@ class AssertPrints(object):
         printed = self.buffer.getvalue()
         assert self.s in printed, notprinted_msg.format(self.s, self.channel, printed)
         return False
+
+printed_msg = """Found {0!r} in printed output (on {1}):
+-------
+{2!s}
+-------
+"""
+
+class AssertNotPrints(AssertPrints):
+    """Context manager for checking that certain output *isn't* produced.
+
+    Counterpart of AssertPrints: takes the same arguments, but fails if
+    ``s`` *does* appear in what is written to ``channel`` within the block."""
+    def __exit__(self, etype, value, traceback):
+        self.tee.flush()
+        setattr(sys, self.channel, self.orig_stream)
+        printed = self.buffer.getvalue()
+        assert self.s not in printed, printed_msg.format(self.s, self.channel, printed)
+        return False
 
 @contextmanager
 def mute_warn():