From 7433cd5dbfced29a529799956b01979feb8b94fa 2014-09-07 21:47:26
From: Thomas Kluyver <takowl@gmail.com>
Date: 2014-09-07 21:47:26
Subject: [PATCH] Merge pull request #6429 from quantopian/transformer-rejection

Make it possible for AST transformers to reject input.
---

diff --git a/IPython/core/error.py b/IPython/core/error.py
index 5346498..684cbc8 100644
--- a/IPython/core/error.py
+++ b/IPython/core/error.py
@@ -51,3 +51,10 @@ class StdinNotImplementedError(IPythonCoreError, NotImplementedError):
     For use in IPython kernels, where only some frontends may support
     stdin requests.
     """
+
+class InputRejected(Exception):
+    """Input rejected by an AST transformer.
+
+    Raise this in your NodeTransformer to tell InteractiveShell not to execute
+    the supplied input; unlike other errors, the transformer stays registered.
+    """
diff --git a/IPython/core/interactiveshell.py b/IPython/core/interactiveshell.py
index b78654e..56aea6d 100644
--- a/IPython/core/interactiveshell.py
+++ b/IPython/core/interactiveshell.py
@@ -41,7 +41,7 @@ from IPython.core.compilerop import CachingCompiler, check_linecache_ipython
 from IPython.core.display_trap import DisplayTrap
 from IPython.core.displayhook import DisplayHook
 from IPython.core.displaypub import DisplayPublisher
-from IPython.core.error import UsageError
+from IPython.core.error import InputRejected, UsageError
 from IPython.core.extensions import ExtensionManager
 from IPython.core.formatters import DisplayFormatter
 from IPython.core.history import HistoryManager
@@ -2786,7 +2786,13 @@ class InteractiveShell(SingletonConfigurable):
                     return None
 
                 # Apply AST transformations
-                code_ast = self.transform_ast(code_ast)
+                try:
+                    code_ast = self.transform_ast(code_ast)
+                except InputRejected:  # a transformer vetoed this input
+                    self.showtraceback()
+                    if store_history:
+                        self.execution_count += 1
+                    return None
 
                 # Execute the user code
                 interactivity = "none" if silent else self.ast_node_interactivity
@@ -2822,6 +2828,11 @@ class InteractiveShell(SingletonConfigurable):
         for transformer in self.ast_transformers:
             try:
                 node = transformer.visit(node)
+            except InputRejected:
+                # AST transformers reject input by raising InputRejected.
+                # Re-raise it for the caller to report, skipping the generic
+                # handler below so that the transformer is not unregistered.
+                raise
             except Exception:
                 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
                 self.ast_transformers.remove(transformer)
diff --git a/IPython/core/tests/test_interactiveshell.py b/IPython/core/tests/test_interactiveshell.py
index de8d8d7..4d95a68 100644
--- a/IPython/core/tests/test_interactiveshell.py
+++ b/IPython/core/tests/test_interactiveshell.py
@@ -24,6 +24,7 @@ from os.path import join
 
 import nose.tools as nt
 
+from IPython.core.error import InputRejected
 from IPython.core.inputtransformer import InputTransformer
 from IPython.testing.decorators import (
     skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
@@ -681,7 +682,7 @@ class TestAstTransform2(unittest.TestCase):
 
 class ErrorTransformer(ast.NodeTransformer):
     """Throws an error when it sees a number."""
-    def visit_Num(self):
+    def visit_Num(self, node):
         raise ValueError("test")
 
 class TestAstTransformError(unittest.TestCase):
@@ -695,6 +696,41 @@ class TestAstTransformError(unittest.TestCase):
         # This should have been removed.
         nt.assert_not_in(err_transformer, ip.ast_transformers)
 
+
+class StringRejector(ast.NodeTransformer):
+    """Raises InputRejected when it sees a string literal.
+
+    Used to verify that NodeTransformers can signal that a piece of code should
+    not be executed by raising InputRejected.
+    """
+
+    def visit_Str(self, node):
+        raise InputRejected("test")
+
+
+class TestAstTransformInputRejection(unittest.TestCase):
+
+    def setUp(self):
+        self.transformer = StringRejector()
+        ip.ast_transformers.append(self.transformer)
+
+    def tearDown(self):
+        ip.ast_transformers.remove(self.transformer)
+
+    def test_input_rejection(self):
+        """Check that NodeTransformers can reject input."""
+
+        expect_exception_tb = tt.AssertPrints("InputRejected: test")
+        expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
+
+        # Run the same check twice to verify that the transformer is not
+        # unregistered after raising (contrast TestAstTransformError).
+        with expect_exception_tb, expect_no_cell_output:
+            ip.run_cell("'unsafe'")
+
+        with expect_exception_tb, expect_no_cell_output:
+            ip.run_cell("'unsafe'")
+
 def test__IPYTHON__():
     # This shouldn't raise a NameError, that's all
     __IPYTHON__
diff --git a/IPython/testing/tools.py b/IPython/testing/tools.py
index 5b95094..5970775 100644
--- a/IPython/testing/tools.py
+++ b/IPython/testing/tools.py
@@ -365,18 +365,21 @@ class AssertPrints(object):
         setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
     
     def __exit__(self, etype, value, traceback):
-        if value is not None:
-            # If an error was raised, don't check anything else
+        try:
+            if value is not None:
+                # An exception escaped the with-block: skip the output checks.
+                return False
+            self.tee.flush()
+            setattr(sys, self.channel, self.orig_stream)
+            printed = self.buffer.getvalue()
+            for s in self.s:
+                if isinstance(s, _re_type):
+                    assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
+                else:
+                    assert s in printed, notprinted_msg.format(s, self.channel, printed)
             return False
-        self.tee.flush()
-        setattr(sys, self.channel, self.orig_stream)
-        printed = self.buffer.getvalue()
-        for s in self.s:
-            if isinstance(s, _re_type):
-                assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
-            else:
-                assert s in printed, notprinted_msg.format(s, self.channel, printed)
-        return False
+        finally:
+            self.tee.close()  # runs on every exit path, even if an assert failed
 
 printed_msg = """Found {0!r} in printed output (on {1}):
 -------
@@ -389,18 +392,24 @@ class AssertNotPrints(AssertPrints):
     
     Counterpart of AssertPrints"""
     def __exit__(self, etype, value, traceback):
-        if value is not None:
-            # If an error was raised, don't check anything else
+        try:
+            if value is not None:
+                # If an error was raised, don't check anything else. The
+                # finally clause below closes the Tee (once) on this path too.
+                return False
+            self.tee.flush()
+            setattr(sys, self.channel, self.orig_stream)
+            printed = self.buffer.getvalue()
+            for s in self.s:
+                if isinstance(s, _re_type):
+                    assert not s.search(printed), printed_msg.format(
+                        s.pattern, self.channel, printed)
+                else:
+                    assert s not in printed, printed_msg.format(
+                        s, self.channel, printed)
             return False
-        self.tee.flush()
-        setattr(sys, self.channel, self.orig_stream)
-        printed = self.buffer.getvalue()
-        for s in self.s:
-            if isinstance(s, _re_type):
-                assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
-            else:
-                assert s not in printed, printed_msg.format(s, self.channel, printed)
-        return False
+        finally:
+            self.tee.close()  # the only close, on every exit path
 
 @contextmanager
 def mute_warn():