##// END OF EJS Templates
Merge pull request #6429 from quantopian/transformer-rejection...
Thomas Kluyver -
r17802:7433cd5d merge
parent child Browse files
Show More
@@ -51,3 +51,10 b' class StdinNotImplementedError(IPythonCoreError, NotImplementedError):'
51 51 For use in IPython kernels, where only some frontends may support
52 52 stdin requests.
53 53 """
54
55 class InputRejected(Exception):
56 """Input rejected by ast transformer.
57
58 Raise this in your NodeTransformer to indicate that InteractiveShell should
59 not execute the supplied input.
60 """
@@ -41,7 +41,7 b' from IPython.core.compilerop import CachingCompiler, check_linecache_ipython'
41 41 from IPython.core.display_trap import DisplayTrap
42 42 from IPython.core.displayhook import DisplayHook
43 43 from IPython.core.displaypub import DisplayPublisher
44 from IPython.core.error import UsageError
44 from IPython.core.error import InputRejected, UsageError
45 45 from IPython.core.extensions import ExtensionManager
46 46 from IPython.core.formatters import DisplayFormatter
47 47 from IPython.core.history import HistoryManager
@@ -2786,7 +2786,13 b' class InteractiveShell(SingletonConfigurable):'
2786 2786 return None
2787 2787
2788 2788 # Apply AST transformations
2789 code_ast = self.transform_ast(code_ast)
2789 try:
2790 code_ast = self.transform_ast(code_ast)
2791 except InputRejected:
2792 self.showtraceback()
2793 if store_history:
2794 self.execution_count += 1
2795 return None
2790 2796
2791 2797 # Execute the user code
2792 2798 interactivity = "none" if silent else self.ast_node_interactivity
@@ -2822,6 +2828,11 b' class InteractiveShell(SingletonConfigurable):'
2822 2828 for transformer in self.ast_transformers:
2823 2829 try:
2824 2830 node = transformer.visit(node)
2831 except InputRejected:
2832 # User-supplied AST transformers can reject an input by raising
2833 # an InputRejected. Short-circuit in this case so that we
2834 # don't unregister the transform.
2835 raise
2825 2836 except Exception:
2826 2837 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
2827 2838 self.ast_transformers.remove(transformer)
@@ -24,6 +24,7 b' from os.path import join'
24 24
25 25 import nose.tools as nt
26 26
27 from IPython.core.error import InputRejected
27 28 from IPython.core.inputtransformer import InputTransformer
28 29 from IPython.testing.decorators import (
29 30 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
@@ -681,7 +682,7 b' class TestAstTransform2(unittest.TestCase):'
681 682
682 683 class ErrorTransformer(ast.NodeTransformer):
683 684 """Throws an error when it sees a number."""
684 def visit_Num(self):
685 def visit_Num(self, node):
685 686 raise ValueError("test")
686 687
687 688 class TestAstTransformError(unittest.TestCase):
@@ -695,6 +696,41 b' class TestAstTransformError(unittest.TestCase):'
695 696 # This should have been removed.
696 697 nt.assert_not_in(err_transformer, ip.ast_transformers)
697 698
699
700 class StringRejector(ast.NodeTransformer):
701 """Throws an InputRejected when it sees a string literal.
702
703 Used to verify that NodeTransformers can signal that a piece of code should
704 not be executed by throwing an InputRejected.
705 """
706
707 def visit_Str(self, node):
708 raise InputRejected("test")
709
710
711 class TestAstTransformInputRejection(unittest.TestCase):
712
713 def setUp(self):
714 self.transformer = StringRejector()
715 ip.ast_transformers.append(self.transformer)
716
717 def tearDown(self):
718 ip.ast_transformers.remove(self.transformer)
719
720 def test_input_rejection(self):
721 """Check that NodeTransformers can reject input."""
722
723 expect_exception_tb = tt.AssertPrints("InputRejected: test")
724 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
725
726 # Run the same check twice to verify that the transformer is not
727 # disabled after raising.
728 with expect_exception_tb, expect_no_cell_output:
729 ip.run_cell("'unsafe'")
730
731 with expect_exception_tb, expect_no_cell_output:
732 ip.run_cell("'unsafe'")
733
698 734 def test__IPYTHON__():
699 735 # This shouldn't raise a NameError, that's all
700 736 __IPYTHON__
@@ -365,18 +365,21 b' class AssertPrints(object):'
365 365 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
366 366
367 367 def __exit__(self, etype, value, traceback):
368 if value is not None:
369 # If an error was raised, don't check anything else
368 try:
369 if value is not None:
370 # If an error was raised, don't check anything else
371 return False
372 self.tee.flush()
373 setattr(sys, self.channel, self.orig_stream)
374 printed = self.buffer.getvalue()
375 for s in self.s:
376 if isinstance(s, _re_type):
377 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
378 else:
379 assert s in printed, notprinted_msg.format(s, self.channel, printed)
370 380 return False
371 self.tee.flush()
372 setattr(sys, self.channel, self.orig_stream)
373 printed = self.buffer.getvalue()
374 for s in self.s:
375 if isinstance(s, _re_type):
376 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
377 else:
378 assert s in printed, notprinted_msg.format(s, self.channel, printed)
379 return False
381 finally:
382 self.tee.close()
380 383
381 384 printed_msg = """Found {0!r} in printed output (on {1}):
382 385 -------
@@ -389,18 +392,24 b' class AssertNotPrints(AssertPrints):'
389 392
390 393 Counterpart of AssertPrints"""
391 394 def __exit__(self, etype, value, traceback):
392 if value is not None:
393 # If an error was raised, don't check anything else
395 try:
396 if value is not None:
397 # If an error was raised, don't check anything else
398 self.tee.close()
399 return False
400 self.tee.flush()
401 setattr(sys, self.channel, self.orig_stream)
402 printed = self.buffer.getvalue()
403 for s in self.s:
404 if isinstance(s, _re_type):
405 assert not s.search(printed),printed_msg.format(
406 s.pattern, self.channel, printed)
407 else:
408 assert s not in printed, printed_msg.format(
409 s, self.channel, printed)
394 410 return False
395 self.tee.flush()
396 setattr(sys, self.channel, self.orig_stream)
397 printed = self.buffer.getvalue()
398 for s in self.s:
399 if isinstance(s, _re_type):
400 assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed)
401 else:
402 assert s not in printed, printed_msg.format(s, self.channel, printed)
403 return False
411 finally:
412 self.tee.close()
404 413
405 414 @contextmanager
406 415 def mute_warn():
General Comments 0
You need to be logged in to leave comments. Login now