diff --git a/IPython/testing/tests/test_tools.py b/IPython/testing/tests/test_tools.py
index 9dea513..8ceb9b3 100644
--- a/IPython/testing/tests/test_tools.py
+++ b/IPython/testing/tests/test_tools.py
@@ -17,6 +17,7 @@ from __future__ import with_statement
 
 import os
 import sys
+import unittest
 
 import nose.tools as nt
 
@@ -71,3 +72,26 @@ def test_temp_pyfile():
     with open(fname) as fh2:
         src2 = fh2.read()
     yield nt.assert_equal(src2, src)
+
+class TestAssertPrints(unittest.TestCase):
+    def test_passing(self):
+        with tt.AssertPrints("abc"):
+            print "abcd"
+            print "def"
+            print b"ghi"
+
+    def test_failing(self):
+        def func():
+            with tt.AssertPrints("abc"):
+                print "acd"
+                print "def"
+                print b"ghi"
+
+        self.assertRaises(AssertionError, func)
+
+    def test_exception_propagates(self):
+        def func():
+            with tt.AssertPrints("abc"):
+                raise ValueError("boom")
+
+        self.assertRaises(ValueError, func)
diff --git a/IPython/testing/tools.py b/IPython/testing/tools.py
index 6c634aa..5496699 100644
--- a/IPython/testing/tools.py
+++ b/IPython/testing/tools.py
@@ -34,6 +34,7 @@ import sys
 import tempfile
 
 from contextlib import contextmanager
+from io import StringIO
 
 try:
     # These tools are used by parts of the runtime, so we make the nose
@@ -46,9 +47,9 @@ except ImportError:
 
 from IPython.config.loader import Config
 from IPython.utils.process import find_cmd, getoutputerror
-from IPython.utils.text import list_strings
-from IPython.utils.io import temp_pyfile
-from IPython.utils.py3compat import PY3
+from IPython.utils.text import list_strings, getdefaultencoding
+from IPython.utils.io import temp_pyfile, Tee
+from IPython.utils import py3compat
 
 from . import decorators as dec
 from . import skipdoctest
@@ -210,7 +211,7 @@ def ipexec(fname, options=None):
     _ip = get_ipython()
     test_dir = os.path.dirname(__file__)
 
-    ipython_cmd = find_cmd('ipython3' if PY3 else 'ipython')
+    ipython_cmd = find_cmd('ipython3' if py3compat.PY3 else 'ipython')
     # Absolute path for filename
     full_fname = os.path.join(test_dir, fname)
     full_cmd = '%s %s %s' % (ipython_cmd, cmdargs, full_fname)
@@ -324,6 +325,62 @@ def check_pairs(func, pairs):
         out = func(inp)
         assert out == expected, pair_fail_msg.format(name, inp, expected, out)
 
+if py3compat.PY3:
+    MyStringIO = StringIO
+else:
+    # In Python 2, stdout/stderr can have either bytes or unicode written to them,
+    # so we need a class that can handle both.
+    class MyStringIO(StringIO):
+        def write(self, s):
+            s = py3compat.cast_unicode(s, encoding=getdefaultencoding())
+            super(MyStringIO, self).write(s)
+
+notprinted_msg = """Did not find {0!r} in printed output (on {1}):
+{2!r}"""
+
+class AssertPrints(object):
+    """Context manager for testing that code prints certain text.
+
+    Parameters
+    ----------
+    s : str
+        Text that must appear somewhere in the captured output.
+    channel : str
+        Name of the sys attribute to watch: 'stdout' (default) or 'stderr'.
+
+    Examples
+    --------
+    >>> with AssertPrints("abc"):
+    ...     print "abcd"
+    ...     print "def"
+    ...
+    abcd
+    def
+    """
+    def __init__(self, s, channel='stdout'):
+        self.s = s
+        self.channel = channel
+
+    def __enter__(self):
+        self.orig_stream = getattr(sys, self.channel)
+        self.buffer = MyStringIO()
+        self.tee = Tee(self.buffer, channel=self.channel)
+        setattr(sys, self.channel, self.tee)
+
+    def __exit__(self, etype, value, traceback):
+        try:
+            self.tee.flush()
+        finally:
+            # Always put the real stream back, even if flushing failed.
+            setattr(sys, self.channel, self.orig_stream)
+        if etype is not None:
+            # The block raised: let that error propagate rather than
+            # masking it with an unrelated "not printed" failure.
+            return False
+        printed = self.buffer.getvalue()
+        assert self.s in printed, notprinted_msg.format(self.s, self.channel, printed)
+        return False
+
 @contextmanager
 def mute_warn():
     from IPython.utils import warn
diff --git a/IPython/utils/tests/test_path.py b/IPython/utils/tests/test_path.py
index bcf23b6..3cfd0ad 100644
--- a/IPython/utils/tests/test_path.py
+++ b/IPython/utils/tests/test_path.py
@@ -29,7 +29,7 @@ from nose import with_setup
 import IPython
 from IPython.testing import decorators as dec
 from IPython.testing.decorators import skip_if_not_win32, skip_win32
-from IPython.testing.tools import make_tempfile
+from IPython.testing.tools import make_tempfile, AssertPrints
 from IPython.utils import path, io
 from IPython.utils import py3compat
 
@@ -404,13 +404,8 @@ def test_not_writable_ipdir():
     ipdir = os.path.join(tmpdir, '.ipython')
     os.mkdir(ipdir)
     os.chmod(ipdir, 600)
-    stderr = io.stderr
-    pipe = StringIO()
-    io.stderr = pipe
-    ipdir = path.get_ipython_dir()
-    io.stderr.flush()
-    io.stderr = stderr
-    nt.assert_true('WARNING' in pipe.getvalue())
+    with AssertPrints('WARNING', channel='stderr'):
+        ipdir = path.get_ipython_dir()
     env.pop('IPYTHON_DIR', None)
 
 def test_unquote_filename():