diff --git a/IPython/core/error.py b/IPython/core/error.py
index 5346498..684cbc8 100644
--- a/IPython/core/error.py
+++ b/IPython/core/error.py
@@ -51,3 +51,10 @@ class StdinNotImplementedError(IPythonCoreError, NotImplementedError):
     For use in IPython kernels, where only some frontends
     may support stdin requests.
     """
+
+class InputRejected(Exception):
+    """Input rejected by ast transformer.
+
+    Raise this in your NodeTransformer to indicate that InteractiveShell should
+    not execute the supplied input.
+    """
diff --git a/IPython/core/interactiveshell.py b/IPython/core/interactiveshell.py
index b78654e..56aea6d 100644
--- a/IPython/core/interactiveshell.py
+++ b/IPython/core/interactiveshell.py
@@ -41,7 +41,7 @@ from IPython.core.compilerop import CachingCompiler, check_linecache_ipython
 from IPython.core.display_trap import DisplayTrap
 from IPython.core.displayhook import DisplayHook
 from IPython.core.displaypub import DisplayPublisher
-from IPython.core.error import UsageError
+from IPython.core.error import InputRejected, UsageError
 from IPython.core.extensions import ExtensionManager
 from IPython.core.formatters import DisplayFormatter
 from IPython.core.history import HistoryManager
@@ -2786,7 +2786,13 @@ class InteractiveShell(SingletonConfigurable):
                     return None
 
                 # Apply AST transformations
-                code_ast = self.transform_ast(code_ast)
+                try:
+                    code_ast = self.transform_ast(code_ast)
+                except InputRejected:
+                    self.showtraceback()
+                    if store_history:
+                        self.execution_count += 1
+                    return None
 
                 # Execute the user code
                 interactivity = "none" if silent else self.ast_node_interactivity
@@ -2822,6 +2828,11 @@ class InteractiveShell(SingletonConfigurable):
         for transformer in self.ast_transformers:
             try:
                 node = transformer.visit(node)
+            except InputRejected:
+                # User-supplied AST transformers can reject an input by raising
+                # an InputRejected. Short-circuit in this case so that we
+                # don't unregister the transform.
+                raise
             except Exception:
                 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
                 self.ast_transformers.remove(transformer)
diff --git a/IPython/core/tests/test_interactiveshell.py b/IPython/core/tests/test_interactiveshell.py
index de8d8d7..4d95a68 100644
--- a/IPython/core/tests/test_interactiveshell.py
+++ b/IPython/core/tests/test_interactiveshell.py
@@ -24,6 +24,7 @@ from os.path import join
 
 import nose.tools as nt
 
+from IPython.core.error import InputRejected
 from IPython.core.inputtransformer import InputTransformer
 from IPython.testing.decorators import (
     skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
@@ -681,7 +682,7 @@ class TestAstTransform2(unittest.TestCase):
 
 class ErrorTransformer(ast.NodeTransformer):
     """Throws an error when it sees a number."""
-    def visit_Num(self):
+    def visit_Num(self, node):
         raise ValueError("test")
 
 class TestAstTransformError(unittest.TestCase):
@@ -695,6 +696,41 @@ class TestAstTransformError(unittest.TestCase):
         # This should have been removed.
         nt.assert_not_in(err_transformer, ip.ast_transformers)
 
+
+class StringRejector(ast.NodeTransformer):
+    """Throws an InputRejected when it sees a string literal.
+
+    Used to verify that NodeTransformers can signal that a piece of code should
+    not be executed by throwing an InputRejected.
+    """
+
+    def visit_Str(self, node):
+        raise InputRejected("test")
+
+
+class TestAstTransformInputRejection(unittest.TestCase):
+
+    def setUp(self):
+        self.transformer = StringRejector()
+        ip.ast_transformers.append(self.transformer)
+
+    def tearDown(self):
+        ip.ast_transformers.remove(self.transformer)
+
+    def test_input_rejection(self):
+        """Check that NodeTransformers can reject input."""
+
+        expect_exception_tb = tt.AssertPrints("InputRejected: test")
+        expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
+
+        # Run the same check twice to verify that the transformer is not
+        # disabled after raising.
+        with expect_exception_tb, expect_no_cell_output:
+            ip.run_cell("'unsafe'")
+
+        with expect_exception_tb, expect_no_cell_output:
+            ip.run_cell("'unsafe'")
+
 def test__IPYTHON__():
     # This shouldn't raise a NameError, that's all
     __IPYTHON__
diff --git a/IPython/testing/tools.py b/IPython/testing/tools.py
index 5b95094..5970775 100644
--- a/IPython/testing/tools.py
+++ b/IPython/testing/tools.py
@@ -365,18 +365,21 @@ class AssertPrints(object):
         setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
 
     def __exit__(self, etype, value, traceback):
-        if value is not None:
-            # If an error was raised, don't check anything else
+        try:
+            if value is not None:
+                # If an error was raised, don't check anything else
+                return False
+            self.tee.flush()
+            setattr(sys, self.channel, self.orig_stream)
+            printed = self.buffer.getvalue()
+            for s in self.s:
+                if isinstance(s, _re_type):
+                    assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
+                else:
+                    assert s in printed, notprinted_msg.format(s, self.channel, printed)
             return False
-        self.tee.flush()
-        setattr(sys, self.channel, self.orig_stream)
-        printed = self.buffer.getvalue()
-        for s in self.s:
-            if isinstance(s, _re_type):
-                assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
-            else:
-                assert s in printed, notprinted_msg.format(s, self.channel, printed)
-        return False
+        finally:
+            self.tee.close()
 
 printed_msg = """Found {0!r} in printed output (on {1}):
 -------
@@ -389,18 +392,24 @@ class AssertNotPrints(AssertPrints):
 
     Counterpart of AssertPrints"""
     def __exit__(self, etype, value, traceback):
-        if value is not None:
-            # If an error was raised, don't check anything else
+        try:
+            if value is not None:
+                # If an error was raised, don't check anything else; the
+                # finally clause below closes the tee.
+                return False
+            self.tee.flush()
+            setattr(sys, self.channel, self.orig_stream)
+            printed = self.buffer.getvalue()
+            for s in self.s:
+                if isinstance(s, _re_type):
+                    assert not s.search(printed), printed_msg.format(
+                        s.pattern, self.channel, printed)
+                else:
+                    assert s not in printed, printed_msg.format(
+                        s, self.channel, printed)
             return False
-        self.tee.flush()
-        setattr(sys, self.channel, self.orig_stream)
-        printed = self.buffer.getvalue()
-        for s in self.s:
-            if isinstance(s, _re_type):
-                assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
-            else:
-                assert s not in printed, printed_msg.format(s, self.channel, printed)
-        return False
+        finally:
+            self.tee.close()
 
 @contextmanager
 def mute_warn():