##// END OF EJS Templates
Merge pull request #2301 from takluyver/ast-transfomers...
Thomas Kluyver -
r8813:a7cc5832 merge
parent child Browse files
Show More
@@ -171,7 +171,7 b' class HistoryAccessor(Configurable):'
171 171 self.hist_file = self._get_hist_file_name(profile)
172 172
173 173 if sqlite3 is None and self.enabled:
174 warn("IPython History requires SQLite, your history will not be saved\n")
174 warn("IPython History requires SQLite, your history will not be saved")
175 175 self.enabled = False
176 176
177 177 self.init_db()
@@ -197,6 +197,13 b' class InteractiveShell(SingletonConfigurable):'
197 197
198 198 _instance = None
199 199
200 ast_transformers = List([], config=True, help=
201 """
202 A list of ast.NodeTransformer subclass instances, which will be applied
203 to user input before code is run.
204 """
205 )
206
200 207 autocall = Enum((0,1,2), default_value=0, config=True, help=
201 208 """
202 209 Make IPython automatically call any callable object even if you didn't
@@ -326,7 +333,7 b' class InteractiveShell(SingletonConfigurable):'
326 333 'prompt_out' : 'out_template',
327 334 'prompts_pad_left' : 'justify',
328 335 }
329 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format(
336 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}".format(
330 337 name=name, newname=table[name])
331 338 )
332 339 # protect against weird cases where self.config may not exist:
@@ -709,7 +716,7 b' class InteractiveShell(SingletonConfigurable):'
709 716 return
710 717
711 718 warn("Attempting to work in a virtualenv. If you encounter problems, please "
712 "install IPython inside the virtualenv.\n")
719 "install IPython inside the virtualenv.")
713 720 if sys.platform == "win32":
714 721 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages')
715 722 else:
@@ -2611,6 +2618,8 b' class InteractiveShell(SingletonConfigurable):'
2611 2618 self.execution_count += 1
2612 2619 return None
2613 2620
2621 code_ast = self.transform_ast(code_ast)
2622
2614 2623 interactivity = "none" if silent else self.ast_node_interactivity
2615 2624 self.run_ast_nodes(code_ast.body, cell_name,
2616 2625 interactivity=interactivity)
@@ -2644,6 +2653,31 b' class InteractiveShell(SingletonConfigurable):'
2644 2653 # Each cell is a *single* input, regardless of how many lines it has
2645 2654 self.execution_count += 1
2646 2655
2656 def transform_ast(self, node):
2657 """Apply the AST transformations from self.ast_transformers
2658
2659 Parameters
2660 ----------
2661 node : ast.Node
2662 The root node to be transformed. Typically called with the ast.Module
2663 produced by parsing user input.
2664
2665 Returns
2666 -------
2667 An ast.Node corresponding to the node it was called with. Note that it
2668 may also modify the passed object, so don't rely on references to the
2669 original AST.
2670 """
2671 for transformer in self.ast_transformers:
2672 try:
2673 node = transformer.visit(node)
2674 except Exception:
2675 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
2676 self.ast_transformers.remove(transformer)
2677
2678 return ast.fix_missing_locations(node)
2679
2680
2647 2681 def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
2648 2682 """Run a sequence of AST nodes. The execution mode depends on the
2649 2683 interactivity parameter.
@@ -14,6 +14,7 b''
14 14
15 15 # Stdlib
16 16 import __builtin__ as builtin_mod
17 import ast
17 18 import bdb
18 19 import os
19 20 import sys
@@ -781,26 +782,54 b' python-profiler package from non-free.""")'
781 782 # but is there a better way to achieve that the code stmt has access
782 783 # to the shell namespace?
783 784 transform = self.shell.input_splitter.transform_cell
785
784 786 if cell is None:
785 787 # called as line magic
786 setup = 'pass'
787 stmt = timeit.reindent(transform(stmt), 8)
788 else:
789 setup = timeit.reindent(transform(stmt), 4)
790 stmt = timeit.reindent(transform(cell), 8)
791
792 # From Python 3.3, this template uses new-style string formatting.
793 if sys.version_info >= (3, 3):
794 src = timeit.template.format(stmt=stmt, setup=setup)
788 ast_setup = ast.parse("pass")
789 ast_stmt = ast.parse(transform(stmt))
795 790 else:
796 src = timeit.template % dict(stmt=stmt, setup=setup)
791 ast_setup = ast.parse(transform(stmt))
792 ast_stmt = ast.parse(transform(cell))
793
794 ast_setup = self.shell.transform_ast(ast_setup)
795 ast_stmt = self.shell.transform_ast(ast_stmt)
796
797 # This codestring is taken from timeit.template - we fill it in as an
798 # AST, so that we can apply our AST transformations to the user code
799 # without affecting the timing code.
800 timeit_ast_template = ast.parse('def inner(_it, _timer):\n'
801 ' setup\n'
802 ' _t0 = _timer()\n'
803 ' for _i in _it:\n'
804 ' stmt\n'
805 ' _t1 = _timer()\n'
806 ' return _t1 - _t0\n')
807
808 class TimeitTemplateFiller(ast.NodeTransformer):
809 "This is quite tightly tied to the template definition above."
810 def visit_FunctionDef(self, node):
811 "Fill in the setup statement"
812 self.generic_visit(node)
813 if node.name == "inner":
814 node.body[:1] = ast_setup.body
815
816 return node
817
818 def visit_For(self, node):
819 "Fill in the statement to be timed"
820 if getattr(getattr(node.body[0], 'value', None), 'id', None) == 'stmt':
821 node.body = ast_stmt.body
822 return node
823
824 timeit_ast = TimeitTemplateFiller().visit(timeit_ast_template)
825 timeit_ast = ast.fix_missing_locations(timeit_ast)
797 826
798 827 # Track compilation time so it can be reported if too long
799 828 # Minimum time above which compilation time will be reported
800 829 tc_min = 0.1
801 830
802 831 t0 = clock()
803 code = compile(src, "<magic-timeit>", "exec")
832 code = compile(timeit_ast, "<magic-timeit>", "exec")
804 833 tc = clock()-t0
805 834
806 835 ns = {}
@@ -885,19 +914,30 b' python-profiler package from non-free.""")'
885 914
886 915 expr = self.shell.prefilter(parameter_s,False)
887 916
917 # Minimum time above which parse time will be reported
918 tp_min = 0.1
919
920 t0 = clock()
921 expr_ast = ast.parse(expr)
922 tp = clock()-t0
923
924 # Apply AST transformations
925 expr_ast = self.shell.transform_ast(expr_ast)
926
888 927 # Minimum time above which compilation time will be reported
889 928 tc_min = 0.1
890 929
891 try:
930 if len(expr_ast.body)==1 and isinstance(expr_ast.body[0], ast.Expr):
892 931 mode = 'eval'
893 t0 = clock()
894 code = compile(expr,'<timed eval>',mode)
895 tc = clock()-t0
896 except SyntaxError:
932 source = '<timed eval>'
933 expr_ast = ast.Expression(expr_ast.body[0].value)
934 else:
897 935 mode = 'exec'
936 source = '<timed exec>'
898 937 t0 = clock()
899 code = compile(expr,'<timed exec>',mode)
938 code = compile(expr_ast, source, mode)
900 939 tc = clock()-t0
940
901 941 # skew measurement as little as possible
902 942 glob = self.shell.user_ns
903 943 wtime = time.time
@@ -923,6 +963,8 b' python-profiler package from non-free.""")'
923 963 print "Wall time: %.2f s" % wall_time
924 964 if tc > tc_min:
925 965 print "Compiler : %.2f s" % tc
966 if tp > tp_min:
967 print "Parser : %.2f s" % tp
926 968 return out
927 969
928 970 @skip_doctest
@@ -20,6 +20,7 b' Authors'
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22 # stdlib
23 import ast
23 24 import os
24 25 import shutil
25 26 import sys
@@ -426,6 +427,132 b' class TestModules(unittest.TestCase, tt.TempFileMixin):'
426 427 out = "False\nFalse\nFalse\n"
427 428 tt.ipexec_validate(self.fname, out)
428 429
430 class Negator(ast.NodeTransformer):
431 """Negates all number literals in an AST."""
432 def visit_Num(self, node):
433 node.n = -node.n
434 return node
435
436 class TestAstTransform(unittest.TestCase):
437 def setUp(self):
438 self.negator = Negator()
439 ip.ast_transformers.append(self.negator)
440
441 def tearDown(self):
442 ip.ast_transformers.remove(self.negator)
443
444 def test_run_cell(self):
445 with tt.AssertPrints('-34'):
446 ip.run_cell('print (12 + 22)')
447
448 # A named reference to a number shouldn't be transformed.
449 ip.user_ns['n'] = 55
450 with tt.AssertNotPrints('-55'):
451 ip.run_cell('print (n)')
452
453 def test_timeit(self):
454 called = set()
455 def f(x):
456 called.add(x)
457 ip.push({'f':f})
458
459 with tt.AssertPrints("best of "):
460 ip.run_line_magic("timeit", "-n1 f(1)")
461 self.assertEqual(called, set([-1]))
462 called.clear()
463
464 with tt.AssertPrints("best of "):
465 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
466 self.assertEqual(called, set([-2, -3]))
467
468 def test_time(self):
469 called = []
470 def f(x):
471 called.append(x)
472 ip.push({'f':f})
473
474 # Test with an expression
475 with tt.AssertPrints("CPU times"):
476 ip.run_line_magic("time", "f(5+9)")
477 self.assertEqual(called, [-14])
478 called[:] = []
479
480 # Test with a statement (different code path)
481 with tt.AssertPrints("CPU times"):
482 ip.run_line_magic("time", "a = f(-3 + -2)")
483 self.assertEqual(called, [5])
484
485 def test_macro(self):
486 ip.push({'a':10})
487 # The AST transformation makes this do a+=-1
488 ip.define_macro("amacro", "a+=1\nprint(a)")
489
490 with tt.AssertPrints("9"):
491 ip.run_cell("amacro")
492 with tt.AssertPrints("8"):
493 ip.run_cell("amacro")
494
495 class IntegerWrapper(ast.NodeTransformer):
496 """Wraps all integers in a call to Integer()"""
497 def visit_Num(self, node):
498 if isinstance(node.n, int):
499 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
500 args=[node], keywords=[])
501 return node
502
503 class TestAstTransform2(unittest.TestCase):
504 def setUp(self):
505 self.intwrapper = IntegerWrapper()
506 ip.ast_transformers.append(self.intwrapper)
507
508 self.calls = []
509 def Integer(*args):
510 self.calls.append(args)
511 return args
512 ip.push({"Integer": Integer})
513
514 def tearDown(self):
515 ip.ast_transformers.remove(self.intwrapper)
516 del ip.user_ns['Integer']
517
518 def test_run_cell(self):
519 ip.run_cell("n = 2")
520 self.assertEqual(self.calls, [(2,)])
521
522 # This shouldn't throw an error
523 ip.run_cell("o = 2.0")
524 self.assertEqual(ip.user_ns['o'], 2.0)
525
526 def test_timeit(self):
527 called = set()
528 def f(x):
529 called.add(x)
530 ip.push({'f':f})
531
532 with tt.AssertPrints("best of "):
533 ip.run_line_magic("timeit", "-n1 f(1)")
534 self.assertEqual(called, set([(1,)]))
535 called.clear()
536
537 with tt.AssertPrints("best of "):
538 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
539 self.assertEqual(called, set([(2,), (3,)]))
540
541 class ErrorTransformer(ast.NodeTransformer):
542 """Throws an error when it sees a number."""
543 def visit_Num(self):
544 raise ValueError("test")
545
546 class TestAstTransformError(unittest.TestCase):
547 def test_unregistering(self):
548 err_transformer = ErrorTransformer()
549 ip.ast_transformers.append(err_transformer)
550
551 with tt.AssertPrints("unregister", channel='stderr'):
552 ip.run_cell("1 + 2")
553
554 # This should have been removed.
555 nt.assert_not_in(err_transformer, ip.ast_transformers)
429 556
430 557 def test__IPYTHON__():
431 558 # This shouldn't raise a NameError, that's all
@@ -351,7 +351,7 b' class TerminalIPythonApp(BaseIPythonApplication, InteractiveShellApp):'
351 351 """Replace --pylab='inline' with --pylab='auto'"""
352 352 if new == 'inline':
353 353 warn.warn("'inline' not available as pylab backend, "
354 "using 'auto' instead.\n")
354 "using 'auto' instead.")
355 355 self.pylab = 'auto'
356 356
357 357 def start(self):
@@ -105,7 +105,7 b' class InputHookManager(object):'
105 105
106 106 def __init__(self):
107 107 if ctypes is None:
108 warn("IPython GUI event loop requires ctypes, %gui will not be available\n")
108 warn("IPython GUI event loop requires ctypes, %gui will not be available")
109 109 return
110 110 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
111 111 self._apps = {}
@@ -319,7 +319,7 b' def make_exclude():'
319 319 continue
320 320 fullpath = pjoin(parent, exclusion)
321 321 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
322 warn("Excluding nonexistent file: %r\n" % exclusion)
322 warn("Excluding nonexistent file: %r" % exclusion)
323 323
324 324 return exclusions
325 325
@@ -130,9 +130,9 b' def import_fail_info(mod_name,fns=None):'
130 130 """Inform load failure for a module."""
131 131
132 132 if fns == None:
133 warn("Loading of %s failed.\n" % (mod_name,))
133 warn("Loading of %s failed." % (mod_name,))
134 134 else:
135 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
135 warn("Loading of %s from %s failed." % (fns,mod_name))
136 136
137 137
138 138 class NotGiven: pass
@@ -42,7 +42,7 b' def warn(msg,level=2,exit_val=1):'
42 42
43 43 if level>0:
44 44 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
45 io.stderr.write('%s%s' % (header[level],msg))
45 print(header[level], msg, sep='', file=io.stderr)
46 46 if level == 4:
47 47 print('Exiting.\n', file=io.stderr)
48 48 sys.exit(exit_val)
General Comments 0
You need to be logged in to leave comments. Login now