From a94d559bad16f3cf3c973056a3f102f2868d1b19 2014-02-09 04:02:56 From: MinRK Date: 2014-02-09 04:02:56 Subject: [PATCH] add regex support to AssertPrints --- diff --git a/IPython/testing/tools.py b/IPython/testing/tools.py index 1dacec4..daaf20f 100644 --- a/IPython/testing/tools.py +++ b/IPython/testing/tools.py @@ -327,6 +327,8 @@ else: s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING) super(MyStringIO, self).write(s) +_re_type = type(re.compile(r'')) + notprinted_msg = """Did not find {0!r} in printed output (on {1}): ------- {2!s} @@ -347,7 +349,7 @@ class AssertPrints(object): """ def __init__(self, s, channel='stdout', suppress=True): self.s = s - if isinstance(self.s, py3compat.string_types): + if isinstance(self.s, (py3compat.string_types, _re_type)): self.s = [self.s] self.channel = channel self.suppress = suppress @@ -366,7 +368,10 @@ class AssertPrints(object): setattr(sys, self.channel, self.orig_stream) printed = self.buffer.getvalue() for s in self.s: - assert s in printed, notprinted_msg.format(s, self.channel, printed) + if isinstance(s, _re_type): + assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed) + else: + assert s in printed, notprinted_msg.format(s, self.channel, printed) return False printed_msg = """Found {0!r} in printed output (on {1}): @@ -387,7 +392,10 @@ class AssertNotPrints(AssertPrints): setattr(sys, self.channel, self.orig_stream) printed = self.buffer.getvalue() for s in self.s: - assert s not in printed, printed_msg.format(s, self.channel, printed) + if isinstance(s, _re_type): + assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed) + else: + assert s not in printed, printed_msg.format(s, self.channel, printed) return False @contextmanager