Show More
@@ -51,3 +51,10 b' class StdinNotImplementedError(IPythonCoreError, NotImplementedError):' | |||||
51 | For use in IPython kernels, where only some frontends may support |
|
51 | For use in IPython kernels, where only some frontends may support | |
52 | stdin requests. |
|
52 | stdin requests. | |
53 | """ |
|
53 | """ | |
|
54 | ||||
|
55 | class InputRejected(Exception): | |||
|
56 | """Input rejected by ast transformer. | |||
|
57 | ||||
|
58 | Raise this in your NodeTransformer to indicate that InteractiveShell should | |||
|
59 | not execute the supplied input. | |||
|
60 | """ |
@@ -41,7 +41,7 b' from IPython.core.compilerop import CachingCompiler, check_linecache_ipython' | |||||
41 | from IPython.core.display_trap import DisplayTrap |
|
41 | from IPython.core.display_trap import DisplayTrap | |
42 | from IPython.core.displayhook import DisplayHook |
|
42 | from IPython.core.displayhook import DisplayHook | |
43 | from IPython.core.displaypub import DisplayPublisher |
|
43 | from IPython.core.displaypub import DisplayPublisher | |
44 | from IPython.core.error import UsageError |
|
44 | from IPython.core.error import InputRejected, UsageError | |
45 | from IPython.core.extensions import ExtensionManager |
|
45 | from IPython.core.extensions import ExtensionManager | |
46 | from IPython.core.formatters import DisplayFormatter |
|
46 | from IPython.core.formatters import DisplayFormatter | |
47 | from IPython.core.history import HistoryManager |
|
47 | from IPython.core.history import HistoryManager | |
@@ -2786,7 +2786,13 b' class InteractiveShell(SingletonConfigurable):' | |||||
2786 | return None |
|
2786 | return None | |
2787 |
|
2787 | |||
2788 | # Apply AST transformations |
|
2788 | # Apply AST transformations | |
2789 | code_ast = self.transform_ast(code_ast) |
|
2789 | try: | |
|
2790 | code_ast = self.transform_ast(code_ast) | |||
|
2791 | except InputRejected: | |||
|
2792 | self.showtraceback() | |||
|
2793 | if store_history: | |||
|
2794 | self.execution_count += 1 | |||
|
2795 | return None | |||
2790 |
|
2796 | |||
2791 | # Execute the user code |
|
2797 | # Execute the user code | |
2792 | interactivity = "none" if silent else self.ast_node_interactivity |
|
2798 | interactivity = "none" if silent else self.ast_node_interactivity | |
@@ -2822,6 +2828,11 b' class InteractiveShell(SingletonConfigurable):' | |||||
2822 | for transformer in self.ast_transformers: |
|
2828 | for transformer in self.ast_transformers: | |
2823 | try: |
|
2829 | try: | |
2824 | node = transformer.visit(node) |
|
2830 | node = transformer.visit(node) | |
|
2831 | except InputRejected: | |||
|
2832 | # User-supplied AST transformers can reject an input by raising | |||
|
2833 | # an InputRejected. Short-circuit in this case so that we | |||
|
2834 | # don't unregister the transform. | |||
|
2835 | raise | |||
2825 | except Exception: |
|
2836 | except Exception: | |
2826 | warn("AST transformer %r threw an error. It will be unregistered." % transformer) |
|
2837 | warn("AST transformer %r threw an error. It will be unregistered." % transformer) | |
2827 | self.ast_transformers.remove(transformer) |
|
2838 | self.ast_transformers.remove(transformer) |
@@ -24,6 +24,7 b' from os.path import join' | |||||
24 |
|
24 | |||
25 | import nose.tools as nt |
|
25 | import nose.tools as nt | |
26 |
|
26 | |||
|
27 | from IPython.core.error import InputRejected | |||
27 | from IPython.core.inputtransformer import InputTransformer |
|
28 | from IPython.core.inputtransformer import InputTransformer | |
28 | from IPython.testing.decorators import ( |
|
29 | from IPython.testing.decorators import ( | |
29 | skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist, |
|
30 | skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist, | |
@@ -681,7 +682,7 b' class TestAstTransform2(unittest.TestCase):' | |||||
681 |
|
682 | |||
682 | class ErrorTransformer(ast.NodeTransformer): |
|
683 | class ErrorTransformer(ast.NodeTransformer): | |
683 | """Throws an error when it sees a number.""" |
|
684 | """Throws an error when it sees a number.""" | |
684 | def visit_Num(self): |
|
685 | def visit_Num(self, node): | |
685 | raise ValueError("test") |
|
686 | raise ValueError("test") | |
686 |
|
687 | |||
687 | class TestAstTransformError(unittest.TestCase): |
|
688 | class TestAstTransformError(unittest.TestCase): | |
@@ -695,6 +696,41 b' class TestAstTransformError(unittest.TestCase):' | |||||
695 | # This should have been removed. |
|
696 | # This should have been removed. | |
696 | nt.assert_not_in(err_transformer, ip.ast_transformers) |
|
697 | nt.assert_not_in(err_transformer, ip.ast_transformers) | |
697 |
|
698 | |||
|
699 | ||||
|
700 | class StringRejector(ast.NodeTransformer): | |||
|
701 | """Throws an InputRejected when it sees a string literal. | |||
|
702 | ||||
|
703 | Used to verify that NodeTransformers can signal that a piece of code should | |||
|
704 | not be executed by throwing an InputRejected. | |||
|
705 | """ | |||
|
706 | ||||
|
707 | def visit_Str(self, node): | |||
|
708 | raise InputRejected("test") | |||
|
709 | ||||
|
710 | ||||
|
711 | class TestAstTransformInputRejection(unittest.TestCase): | |||
|
712 | ||||
|
713 | def setUp(self): | |||
|
714 | self.transformer = StringRejector() | |||
|
715 | ip.ast_transformers.append(self.transformer) | |||
|
716 | ||||
|
717 | def tearDown(self): | |||
|
718 | ip.ast_transformers.remove(self.transformer) | |||
|
719 | ||||
|
720 | def test_input_rejection(self): | |||
|
721 | """Check that NodeTransformers can reject input.""" | |||
|
722 | ||||
|
723 | expect_exception_tb = tt.AssertPrints("InputRejected: test") | |||
|
724 | expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False) | |||
|
725 | ||||
|
726 | # Run the same check twice to verify that the transformer is not | |||
|
727 | # disabled after raising. | |||
|
728 | with expect_exception_tb, expect_no_cell_output: | |||
|
729 | ip.run_cell("'unsafe'") | |||
|
730 | ||||
|
731 | with expect_exception_tb, expect_no_cell_output: | |||
|
732 | ip.run_cell("'unsafe'") | |||
|
733 | ||||
698 | def test__IPYTHON__(): |
|
734 | def test__IPYTHON__(): | |
699 | # This shouldn't raise a NameError, that's all |
|
735 | # This shouldn't raise a NameError, that's all | |
700 | __IPYTHON__ |
|
736 | __IPYTHON__ |
@@ -365,18 +365,21 b' class AssertPrints(object):' | |||||
365 | setattr(sys, self.channel, self.buffer if self.suppress else self.tee) |
|
365 | setattr(sys, self.channel, self.buffer if self.suppress else self.tee) | |
366 |
|
366 | |||
367 | def __exit__(self, etype, value, traceback): |
|
367 | def __exit__(self, etype, value, traceback): | |
368 | if value is not None: |
|
368 | try: | |
369 | # If an error was raised, don't check anything else |
|
369 | if value is not None: | |
|
370 | # If an error was raised, don't check anything else | |||
|
371 | return False | |||
|
372 | self.tee.flush() | |||
|
373 | setattr(sys, self.channel, self.orig_stream) | |||
|
374 | printed = self.buffer.getvalue() | |||
|
375 | for s in self.s: | |||
|
376 | if isinstance(s, _re_type): | |||
|
377 | assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed) | |||
|
378 | else: | |||
|
379 | assert s in printed, notprinted_msg.format(s, self.channel, printed) | |||
370 | return False |
|
380 | return False | |
371 | self.tee.flush() |
|
381 | finally: | |
372 | setattr(sys, self.channel, self.orig_stream) |
|
382 | self.tee.close() | |
373 | printed = self.buffer.getvalue() |
|
|||
374 | for s in self.s: |
|
|||
375 | if isinstance(s, _re_type): |
|
|||
376 | assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed) |
|
|||
377 | else: |
|
|||
378 | assert s in printed, notprinted_msg.format(s, self.channel, printed) |
|
|||
379 | return False |
|
|||
380 |
|
383 | |||
381 | printed_msg = """Found {0!r} in printed output (on {1}): |
|
384 | printed_msg = """Found {0!r} in printed output (on {1}): | |
382 | ------- |
|
385 | ------- | |
@@ -389,18 +392,24 b' class AssertNotPrints(AssertPrints):' | |||||
389 |
|
392 | |||
390 | Counterpart of AssertPrints""" |
|
393 | Counterpart of AssertPrints""" | |
391 | def __exit__(self, etype, value, traceback): |
|
394 | def __exit__(self, etype, value, traceback): | |
392 | if value is not None: |
|
395 | try: | |
393 | # If an error was raised, don't check anything else |
|
396 | if value is not None: | |
|
397 | # If an error was raised, don't check anything else | |||
|
398 | self.tee.close() | |||
|
399 | return False | |||
|
400 | self.tee.flush() | |||
|
401 | setattr(sys, self.channel, self.orig_stream) | |||
|
402 | printed = self.buffer.getvalue() | |||
|
403 | for s in self.s: | |||
|
404 | if isinstance(s, _re_type): | |||
|
405 | assert not s.search(printed),printed_msg.format( | |||
|
406 | s.pattern, self.channel, printed) | |||
|
407 | else: | |||
|
408 | assert s not in printed, printed_msg.format( | |||
|
409 | s, self.channel, printed) | |||
394 | return False |
|
410 | return False | |
395 | self.tee.flush() |
|
411 | finally: | |
396 | setattr(sys, self.channel, self.orig_stream) |
|
412 | self.tee.close() | |
397 | printed = self.buffer.getvalue() |
|
|||
398 | for s in self.s: |
|
|||
399 | if isinstance(s, _re_type): |
|
|||
400 | assert not s.search(printed), printed_msg.format(s.pattern, self.channel, printed) |
|
|||
401 | else: |
|
|||
402 | assert s not in printed, printed_msg.format(s, self.channel, printed) |
|
|||
403 | return False |
|
|||
404 |
|
413 | |||
405 | @contextmanager |
|
414 | @contextmanager | |
406 | def mute_warn(): |
|
415 | def mute_warn(): |
General Comments 0
You need to be logged in to leave comments.
Login now