##// END OF EJS Templates
Merge pull request #2301 from takluyver/ast-transfomers...
Thomas Kluyver -
r8813:a7cc5832 merge
parent child Browse files
Show More
@@ -171,7 +171,7 b' class HistoryAccessor(Configurable):'
171 self.hist_file = self._get_hist_file_name(profile)
171 self.hist_file = self._get_hist_file_name(profile)
172
172
173 if sqlite3 is None and self.enabled:
173 if sqlite3 is None and self.enabled:
174 warn("IPython History requires SQLite, your history will not be saved\n")
174 warn("IPython History requires SQLite, your history will not be saved")
175 self.enabled = False
175 self.enabled = False
176
176
177 self.init_db()
177 self.init_db()
@@ -196,6 +196,13 b' class InteractiveShell(SingletonConfigurable):'
196 """An enhanced, interactive shell for Python."""
196 """An enhanced, interactive shell for Python."""
197
197
198 _instance = None
198 _instance = None
199
200 ast_transformers = List([], config=True, help=
201 """
202 A list of ast.NodeTransformer subclass instances, which will be applied
203 to user input before code is run.
204 """
205 )
199
206
200 autocall = Enum((0,1,2), default_value=0, config=True, help=
207 autocall = Enum((0,1,2), default_value=0, config=True, help=
201 """
208 """
@@ -326,7 +333,7 b' class InteractiveShell(SingletonConfigurable):'
326 'prompt_out' : 'out_template',
333 'prompt_out' : 'out_template',
327 'prompts_pad_left' : 'justify',
334 'prompts_pad_left' : 'justify',
328 }
335 }
329 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format(
336 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}".format(
330 name=name, newname=table[name])
337 name=name, newname=table[name])
331 )
338 )
332 # protect against weird cases where self.config may not exist:
339 # protect against weird cases where self.config may not exist:
@@ -709,7 +716,7 b' class InteractiveShell(SingletonConfigurable):'
709 return
716 return
710
717
711 warn("Attempting to work in a virtualenv. If you encounter problems, please "
718 warn("Attempting to work in a virtualenv. If you encounter problems, please "
712 "install IPython inside the virtualenv.\n")
719 "install IPython inside the virtualenv.")
713 if sys.platform == "win32":
720 if sys.platform == "win32":
714 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages')
721 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages')
715 else:
722 else:
@@ -2611,6 +2618,8 b' class InteractiveShell(SingletonConfigurable):'
2611 self.execution_count += 1
2618 self.execution_count += 1
2612 return None
2619 return None
2613
2620
2621 code_ast = self.transform_ast(code_ast)
2622
2614 interactivity = "none" if silent else self.ast_node_interactivity
2623 interactivity = "none" if silent else self.ast_node_interactivity
2615 self.run_ast_nodes(code_ast.body, cell_name,
2624 self.run_ast_nodes(code_ast.body, cell_name,
2616 interactivity=interactivity)
2625 interactivity=interactivity)
@@ -2643,6 +2652,31 b' class InteractiveShell(SingletonConfigurable):'
2643 self.history_manager.store_output(self.execution_count)
2652 self.history_manager.store_output(self.execution_count)
2644 # Each cell is a *single* input, regardless of how many lines it has
2653 # Each cell is a *single* input, regardless of how many lines it has
2645 self.execution_count += 1
2654 self.execution_count += 1
2655
2656 def transform_ast(self, node):
2657 """Apply the AST transformations from self.ast_transformers
2658
2659 Parameters
2660 ----------
2661 node : ast.Node
2662 The root node to be transformed. Typically called with the ast.Module
2663 produced by parsing user input.
2664
2665 Returns
2666 -------
2667 An ast.Node corresponding to the node it was called with. Note that it
2668 may also modify the passed object, so don't rely on references to the
2669 original AST.
2670 """
2671 for transformer in self.ast_transformers:
2672 try:
2673 node = transformer.visit(node)
2674 except Exception:
2675 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
2676 self.ast_transformers.remove(transformer)
2677
2678 return ast.fix_missing_locations(node)
2679
2646
2680
2647 def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
2681 def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
2648 """Run a sequence of AST nodes. The execution mode depends on the
2682 """Run a sequence of AST nodes. The execution mode depends on the
@@ -14,6 +14,7 b''
14
14
15 # Stdlib
15 # Stdlib
16 import __builtin__ as builtin_mod
16 import __builtin__ as builtin_mod
17 import ast
17 import bdb
18 import bdb
18 import os
19 import os
19 import sys
20 import sys
@@ -781,26 +782,54 b' python-profiler package from non-free.""")'
781 # but is there a better way to achieve that the code stmt has access
782 # but is there a better way to achieve that the code stmt has access
782 # to the shell namespace?
783 # to the shell namespace?
783 transform = self.shell.input_splitter.transform_cell
784 transform = self.shell.input_splitter.transform_cell
785
784 if cell is None:
786 if cell is None:
785 # called as line magic
787 # called as line magic
786 setup = 'pass'
788 ast_setup = ast.parse("pass")
787 stmt = timeit.reindent(transform(stmt), 8)
789 ast_stmt = ast.parse(transform(stmt))
788 else:
789 setup = timeit.reindent(transform(stmt), 4)
790 stmt = timeit.reindent(transform(cell), 8)
791
792 # From Python 3.3, this template uses new-style string formatting.
793 if sys.version_info >= (3, 3):
794 src = timeit.template.format(stmt=stmt, setup=setup)
795 else:
790 else:
796 src = timeit.template % dict(stmt=stmt, setup=setup)
791 ast_setup = ast.parse(transform(stmt))
792 ast_stmt = ast.parse(transform(cell))
793
794 ast_setup = self.shell.transform_ast(ast_setup)
795 ast_stmt = self.shell.transform_ast(ast_stmt)
796
797 # This codestring is taken from timeit.template - we fill it in as an
798 # AST, so that we can apply our AST transformations to the user code
799 # without affecting the timing code.
800 timeit_ast_template = ast.parse('def inner(_it, _timer):\n'
801 ' setup\n'
802 ' _t0 = _timer()\n'
803 ' for _i in _it:\n'
804 ' stmt\n'
805 ' _t1 = _timer()\n'
806 ' return _t1 - _t0\n')
807
808 class TimeitTemplateFiller(ast.NodeTransformer):
809 "This is quite tightly tied to the template definition above."
810 def visit_FunctionDef(self, node):
811 "Fill in the setup statement"
812 self.generic_visit(node)
813 if node.name == "inner":
814 node.body[:1] = ast_setup.body
815
816 return node
817
818 def visit_For(self, node):
819 "Fill in the statement to be timed"
820 if getattr(getattr(node.body[0], 'value', None), 'id', None) == 'stmt':
821 node.body = ast_stmt.body
822 return node
823
824 timeit_ast = TimeitTemplateFiller().visit(timeit_ast_template)
825 timeit_ast = ast.fix_missing_locations(timeit_ast)
797
826
798 # Track compilation time so it can be reported if too long
827 # Track compilation time so it can be reported if too long
799 # Minimum time above which compilation time will be reported
828 # Minimum time above which compilation time will be reported
800 tc_min = 0.1
829 tc_min = 0.1
801
830
802 t0 = clock()
831 t0 = clock()
803 code = compile(src, "<magic-timeit>", "exec")
832 code = compile(timeit_ast, "<magic-timeit>", "exec")
804 tc = clock()-t0
833 tc = clock()-t0
805
834
806 ns = {}
835 ns = {}
@@ -884,20 +913,31 b' python-profiler package from non-free.""")'
884 # fail immediately if the given expression can't be compiled
913 # fail immediately if the given expression can't be compiled
885
914
886 expr = self.shell.prefilter(parameter_s,False)
915 expr = self.shell.prefilter(parameter_s,False)
916
917 # Minimum time above which parse time will be reported
918 tp_min = 0.1
919
920 t0 = clock()
921 expr_ast = ast.parse(expr)
922 tp = clock()-t0
923
924 # Apply AST transformations
925 expr_ast = self.shell.transform_ast(expr_ast)
887
926
888 # Minimum time above which compilation time will be reported
927 # Minimum time above which compilation time will be reported
889 tc_min = 0.1
928 tc_min = 0.1
890
929
891 try:
930 if len(expr_ast.body)==1 and isinstance(expr_ast.body[0], ast.Expr):
892 mode = 'eval'
931 mode = 'eval'
893 t0 = clock()
932 source = '<timed eval>'
894 code = compile(expr,'<timed eval>',mode)
933 expr_ast = ast.Expression(expr_ast.body[0].value)
895 tc = clock()-t0
934 else:
896 except SyntaxError:
897 mode = 'exec'
935 mode = 'exec'
898 t0 = clock()
936 source = '<timed exec>'
899 code = compile(expr,'<timed exec>',mode)
937 t0 = clock()
900 tc = clock()-t0
938 code = compile(expr_ast, source, mode)
939 tc = clock()-t0
940
901 # skew measurement as little as possible
941 # skew measurement as little as possible
902 glob = self.shell.user_ns
942 glob = self.shell.user_ns
903 wtime = time.time
943 wtime = time.time
@@ -923,6 +963,8 b' python-profiler package from non-free.""")'
923 print "Wall time: %.2f s" % wall_time
963 print "Wall time: %.2f s" % wall_time
924 if tc > tc_min:
964 if tc > tc_min:
925 print "Compiler : %.2f s" % tc
965 print "Compiler : %.2f s" % tc
966 if tp > tp_min:
967 print "Parser : %.2f s" % tp
926 return out
968 return out
927
969
928 @skip_doctest
970 @skip_doctest
@@ -20,6 +20,7 b' Authors'
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # stdlib
22 # stdlib
23 import ast
23 import os
24 import os
24 import shutil
25 import shutil
25 import sys
26 import sys
@@ -426,6 +427,132 b' class TestModules(unittest.TestCase, tt.TempFileMixin):'
426 out = "False\nFalse\nFalse\n"
427 out = "False\nFalse\nFalse\n"
427 tt.ipexec_validate(self.fname, out)
428 tt.ipexec_validate(self.fname, out)
428
429
430 class Negator(ast.NodeTransformer):
431 """Negates all number literals in an AST."""
432 def visit_Num(self, node):
433 node.n = -node.n
434 return node
435
436 class TestAstTransform(unittest.TestCase):
437 def setUp(self):
438 self.negator = Negator()
439 ip.ast_transformers.append(self.negator)
440
441 def tearDown(self):
442 ip.ast_transformers.remove(self.negator)
443
444 def test_run_cell(self):
445 with tt.AssertPrints('-34'):
446 ip.run_cell('print (12 + 22)')
447
448 # A named reference to a number shouldn't be transformed.
449 ip.user_ns['n'] = 55
450 with tt.AssertNotPrints('-55'):
451 ip.run_cell('print (n)')
452
453 def test_timeit(self):
454 called = set()
455 def f(x):
456 called.add(x)
457 ip.push({'f':f})
458
459 with tt.AssertPrints("best of "):
460 ip.run_line_magic("timeit", "-n1 f(1)")
461 self.assertEqual(called, set([-1]))
462 called.clear()
463
464 with tt.AssertPrints("best of "):
465 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
466 self.assertEqual(called, set([-2, -3]))
467
468 def test_time(self):
469 called = []
470 def f(x):
471 called.append(x)
472 ip.push({'f':f})
473
474 # Test with an expression
475 with tt.AssertPrints("CPU times"):
476 ip.run_line_magic("time", "f(5+9)")
477 self.assertEqual(called, [-14])
478 called[:] = []
479
480 # Test with a statement (different code path)
481 with tt.AssertPrints("CPU times"):
482 ip.run_line_magic("time", "a = f(-3 + -2)")
483 self.assertEqual(called, [5])
484
485 def test_macro(self):
486 ip.push({'a':10})
487 # The AST transformation makes this do a+=-1
488 ip.define_macro("amacro", "a+=1\nprint(a)")
489
490 with tt.AssertPrints("9"):
491 ip.run_cell("amacro")
492 with tt.AssertPrints("8"):
493 ip.run_cell("amacro")
494
495 class IntegerWrapper(ast.NodeTransformer):
496 """Wraps all integers in a call to Integer()"""
497 def visit_Num(self, node):
498 if isinstance(node.n, int):
499 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
500 args=[node], keywords=[])
501 return node
502
503 class TestAstTransform2(unittest.TestCase):
504 def setUp(self):
505 self.intwrapper = IntegerWrapper()
506 ip.ast_transformers.append(self.intwrapper)
507
508 self.calls = []
509 def Integer(*args):
510 self.calls.append(args)
511 return args
512 ip.push({"Integer": Integer})
513
514 def tearDown(self):
515 ip.ast_transformers.remove(self.intwrapper)
516 del ip.user_ns['Integer']
517
518 def test_run_cell(self):
519 ip.run_cell("n = 2")
520 self.assertEqual(self.calls, [(2,)])
521
522 # This shouldn't throw an error
523 ip.run_cell("o = 2.0")
524 self.assertEqual(ip.user_ns['o'], 2.0)
525
526 def test_timeit(self):
527 called = set()
528 def f(x):
529 called.add(x)
530 ip.push({'f':f})
531
532 with tt.AssertPrints("best of "):
533 ip.run_line_magic("timeit", "-n1 f(1)")
534 self.assertEqual(called, set([(1,)]))
535 called.clear()
536
537 with tt.AssertPrints("best of "):
538 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
539 self.assertEqual(called, set([(2,), (3,)]))
540
541 class ErrorTransformer(ast.NodeTransformer):
542 """Throws an error when it sees a number."""
543 def visit_Num(self):
544 raise ValueError("test")
545
546 class TestAstTransformError(unittest.TestCase):
547 def test_unregistering(self):
548 err_transformer = ErrorTransformer()
549 ip.ast_transformers.append(err_transformer)
550
551 with tt.AssertPrints("unregister", channel='stderr'):
552 ip.run_cell("1 + 2")
553
554 # This should have been removed.
555 nt.assert_not_in(err_transformer, ip.ast_transformers)
429
556
430 def test__IPYTHON__():
557 def test__IPYTHON__():
431 # This shouldn't raise a NameError, that's all
558 # This shouldn't raise a NameError, that's all
@@ -351,7 +351,7 b' class TerminalIPythonApp(BaseIPythonApplication, InteractiveShellApp):'
351 """Replace --pylab='inline' with --pylab='auto'"""
351 """Replace --pylab='inline' with --pylab='auto'"""
352 if new == 'inline':
352 if new == 'inline':
353 warn.warn("'inline' not available as pylab backend, "
353 warn.warn("'inline' not available as pylab backend, "
354 "using 'auto' instead.\n")
354 "using 'auto' instead.")
355 self.pylab = 'auto'
355 self.pylab = 'auto'
356
356
357 def start(self):
357 def start(self):
@@ -105,7 +105,7 b' class InputHookManager(object):'
105
105
106 def __init__(self):
106 def __init__(self):
107 if ctypes is None:
107 if ctypes is None:
108 warn("IPython GUI event loop requires ctypes, %gui will not be available\n")
108 warn("IPython GUI event loop requires ctypes, %gui will not be available")
109 return
109 return
110 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
110 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
111 self._apps = {}
111 self._apps = {}
@@ -319,7 +319,7 b' def make_exclude():'
319 continue
319 continue
320 fullpath = pjoin(parent, exclusion)
320 fullpath = pjoin(parent, exclusion)
321 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
321 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
322 warn("Excluding nonexistent file: %r\n" % exclusion)
322 warn("Excluding nonexistent file: %r" % exclusion)
323
323
324 return exclusions
324 return exclusions
325
325
@@ -130,9 +130,9 b' def import_fail_info(mod_name,fns=None):'
130 """Inform load failure for a module."""
130 """Inform load failure for a module."""
131
131
132 if fns == None:
132 if fns == None:
133 warn("Loading of %s failed.\n" % (mod_name,))
133 warn("Loading of %s failed." % (mod_name,))
134 else:
134 else:
135 warn("Loading of %s from %s failed.\n" % (fns,mod_name))
135 warn("Loading of %s from %s failed." % (fns,mod_name))
136
136
137
137
138 class NotGiven: pass
138 class NotGiven: pass
@@ -42,7 +42,7 b' def warn(msg,level=2,exit_val=1):'
42
42
43 if level>0:
43 if level>0:
44 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
44 header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
45 io.stderr.write('%s%s' % (header[level],msg))
45 print(header[level], msg, sep='', file=io.stderr)
46 if level == 4:
46 if level == 4:
47 print('Exiting.\n', file=io.stderr)
47 print('Exiting.\n', file=io.stderr)
48 sys.exit(exit_val)
48 sys.exit(exit_val)
General Comments 0
You need to be logged in to leave comments. Login now