From a7cc5832e2bb3b997ba5d096630493c014b0ce9a 2012-11-30 21:51:43
From: Thomas Kluyver <takowl@gmail.com>
Date: 2012-11-30 21:51:43
Subject: [PATCH] Merge pull request #2301 from takluyver/ast-transfomers

AST transformers
---

diff --git a/IPython/core/history.py b/IPython/core/history.py
index 0dcf823..558e517 100644
--- a/IPython/core/history.py
+++ b/IPython/core/history.py
@@ -171,7 +171,7 @@ class HistoryAccessor(Configurable):
             self.hist_file = self._get_hist_file_name(profile)
 
         if sqlite3 is None and self.enabled:
-            warn("IPython History requires SQLite, your history will not be saved\n")
+            warn("IPython History requires SQLite, your history will not be saved")  # newline now added by warn() -- assumes IPython.utils.warn.warn
             self.enabled = False
         
         self.init_db()
diff --git a/IPython/core/interactiveshell.py b/IPython/core/interactiveshell.py
index bbebb1f..133b352 100644
--- a/IPython/core/interactiveshell.py
+++ b/IPython/core/interactiveshell.py
@@ -196,6 +196,13 @@ class InteractiveShell(SingletonConfigurable):
     """An enhanced, interactive shell for Python."""
 
     _instance = None
+    
+    ast_transformers = List([], config=True, help=
+        """
+        A list of ast.NodeTransformer subclass instances, which will be applied
+        to user input before code is run.
+        """
+    )
 
     autocall = Enum((0,1,2), default_value=0, config=True, help=
         """
@@ -326,7 +333,7 @@ class InteractiveShell(SingletonConfigurable):
             'prompt_out' : 'out_template',
             'prompts_pad_left' : 'justify',
         }
-        warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format(
+        warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}".format(
                 name=name, newname=table[name])
         )
         # protect against weird cases where self.config may not exist:
@@ -709,7 +716,7 @@ class InteractiveShell(SingletonConfigurable):
             return
         
         warn("Attempting to work in a virtualenv. If you encounter problems, please "
-             "install IPython inside the virtualenv.\n")
+             "install IPython inside the virtualenv.")
         if sys.platform == "win32":
             virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages') 
         else:
@@ -2611,6 +2618,8 @@ class InteractiveShell(SingletonConfigurable):
                             self.execution_count += 1
                         return None
                     
+                    code_ast = self.transform_ast(code_ast)
+                    
                     interactivity = "none" if silent else self.ast_node_interactivity
                     self.run_ast_nodes(code_ast.body, cell_name,
                                        interactivity=interactivity)
@@ -2643,6 +2652,37 @@ class InteractiveShell(SingletonConfigurable):
             self.history_manager.store_output(self.execution_count)
             # Each cell is a *single* input, regardless of how many lines it has
             self.execution_count += 1
+    
+    def transform_ast(self, node):
+        """Apply the AST transformations from self.ast_transformers
+        
+        A transformer that raises an exception is removed from
+        self.ast_transformers (with a warning); the remaining transformers
+        are still applied.
+        
+        Parameters
+        ----------
+        node : ast.AST
+          The root node to be transformed. Typically called with the ast.Module
+          produced by parsing user input.
+        
+        Returns
+        -------
+        An ast.AST corresponding to the node it was called with. Note that it
+        may also modify the passed object, so don't rely on references to the
+        original AST.
+        """
+        # Iterate over a snapshot: removing the failing transformer from the
+        # list being looped over would silently skip the transformer after it.
+        for transformer in list(self.ast_transformers):
+            try:
+                node = transformer.visit(node)
+            except Exception:
+                warn("AST transformer %r threw an error. It will be unregistered." % transformer)
+                self.ast_transformers.remove(transformer)
+        
+        return ast.fix_missing_locations(node)
+                
 
     def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
         """Run a sequence of AST nodes. The execution mode depends on the
diff --git a/IPython/core/magics/execution.py b/IPython/core/magics/execution.py
index 1b916b7..86afccd 100644
--- a/IPython/core/magics/execution.py
+++ b/IPython/core/magics/execution.py
@@ -14,6 +14,7 @@
 
 # Stdlib
 import __builtin__ as builtin_mod
+import ast
 import bdb
 import os
 import sys
@@ -781,26 +782,54 @@ python-profiler package from non-free.""")
         # but is there a better way to achieve that the code stmt has access
         # to the shell namespace?
         transform  = self.shell.input_splitter.transform_cell
+        
         if cell is None:
             # called as line magic
-            setup = 'pass'
-            stmt = timeit.reindent(transform(stmt), 8)
-        else:
-            setup = timeit.reindent(transform(stmt), 4)
-            stmt = timeit.reindent(transform(cell), 8)
-
-        # From Python 3.3, this template uses new-style string formatting.
-        if sys.version_info >= (3, 3):
-            src = timeit.template.format(stmt=stmt, setup=setup)
+            ast_setup = ast.parse("pass")
+            ast_stmt = ast.parse(transform(stmt))
         else:
-            src = timeit.template % dict(stmt=stmt, setup=setup)
+            ast_setup = ast.parse(transform(stmt))
+            ast_stmt = ast.parse(transform(cell))
+        
+        ast_setup = self.shell.transform_ast(ast_setup)
+        ast_stmt = self.shell.transform_ast(ast_stmt)
+        
+        # This codestring is taken from timeit.template - we fill it in as an
+        # AST, so that we can apply our AST transformations to the user code
+        # without affecting the timing code.
+        timeit_ast_template = ast.parse('def inner(_it, _timer):\n'
+                                        '    setup\n'
+                                        '    _t0 = _timer()\n'
+                                        '    for _i in _it:\n'
+                                        '        stmt\n'
+                                        '    _t1 = _timer()\n'
+                                        '    return _t1 - _t0\n')
+        
+        class TimeitTemplateFiller(ast.NodeTransformer):
+            "Fill the setup/stmt placeholders of timeit_ast_template (relies on its exact shape)."
+            def visit_FunctionDef(self, node):
+                "Replace the first statement of inner() (the 'setup' placeholder) with the setup code."
+                self.generic_visit(node)
+                if node.name == "inner":
+                    node.body[:1] = ast_setup.body
+                
+                return node
+            
+            def visit_For(self, node):
+                "Replace a loop body whose first statement is the bare name 'stmt' with the timed code."
+                if getattr(getattr(node.body[0], 'value', None), 'id', None) == 'stmt':
+                    node.body = ast_stmt.body
+                return node
+        
+        timeit_ast = TimeitTemplateFiller().visit(timeit_ast_template)
+        timeit_ast = ast.fix_missing_locations(timeit_ast)
 
         # Track compilation time so it can be reported if too long
         # Minimum time above which compilation time will be reported
         tc_min = 0.1
 
         t0 = clock()
-        code = compile(src, "<magic-timeit>", "exec")
+        code = compile(timeit_ast, "<magic-timeit>", "exec")
         tc = clock()-t0
 
         ns = {}
@@ -884,20 +913,31 @@ python-profiler package from non-free.""")
         # fail immediately if the given expression can't be compiled
 
         expr = self.shell.prefilter(parameter_s,False)
+        
+        # Minimum time above which parse time will be reported
+        tp_min = 0.1
+        
+        t0 = clock()
+        expr_ast = ast.parse(expr)
+        tp = clock()-t0
+        
+        # Apply AST transformations
+        expr_ast = self.shell.transform_ast(expr_ast)
 
         # Minimum time above which compilation time will be reported
         tc_min = 0.1
 
-        try:
+        if len(expr_ast.body)==1 and isinstance(expr_ast.body[0], ast.Expr):  # lone expression -> 'eval' mode
             mode = 'eval'
-            t0 = clock()
-            code = compile(expr,'<timed eval>',mode)
-            tc = clock()-t0
-        except SyntaxError:
+            source = '<timed eval>'
+            expr_ast = ast.Expression(expr_ast.body[0].value)  # 'eval' mode requires an ast.Expression root
+        else:
             mode = 'exec'
-            t0 = clock()
-            code = compile(expr,'<timed exec>',mode)
-            tc = clock()-t0
+            source = '<timed exec>'
+        t0 = clock()
+        code = compile(expr_ast, source, mode)
+        tc = clock()-t0
+        
         # skew measurement as little as possible
         glob = self.shell.user_ns
         wtime = time.time
@@ -923,6 +963,8 @@ python-profiler package from non-free.""")
         print "Wall time: %.2f s" % wall_time
         if tc > tc_min:
             print "Compiler : %.2f s" % tc
+        if tp > tp_min:
+            print "Parser   : %.2f s" % tp
         return out
 
     @skip_doctest
diff --git a/IPython/core/tests/test_interactiveshell.py b/IPython/core/tests/test_interactiveshell.py
index 634d503..cf7f763 100644
--- a/IPython/core/tests/test_interactiveshell.py
+++ b/IPython/core/tests/test_interactiveshell.py
@@ -20,6 +20,7 @@ Authors
 # Imports
 #-----------------------------------------------------------------------------
 # stdlib
+import ast
 import os
 import shutil
 import sys
@@ -426,6 +427,132 @@ class TestModules(unittest.TestCase, tt.TempFileMixin):
         out = "False\nFalse\nFalse\n"
         tt.ipexec_validate(self.fname, out)
 
+class Negator(ast.NodeTransformer):
+    """Negates all number literals in an AST."""
+    def visit_Num(self, node):
+        node.n = -node.n
+        return node
+
+class TestAstTransform(unittest.TestCase):
+    def setUp(self):
+        self.negator = Negator()
+        ip.ast_transformers.append(self.negator)
+    
+    def tearDown(self):
+        ip.ast_transformers.remove(self.negator)
+    
+    def test_run_cell(self):
+        with tt.AssertPrints('-34'):
+            ip.run_cell('print (12 + 22)')
+        
+        # A named reference to a number shouldn't be transformed.
+        ip.user_ns['n'] = 55
+        with tt.AssertNotPrints('-55'):
+            ip.run_cell('print (n)')
+    
+    def test_timeit(self):
+        called = set()
+        def f(x):
+            called.add(x)
+        ip.push({'f':f})
+        
+        with tt.AssertPrints("best of "):
+            ip.run_line_magic("timeit", "-n1 f(1)")
+        self.assertEqual(called, set([-1]))
+        called.clear()
+        
+        with tt.AssertPrints("best of "):
+            ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
+        self.assertEqual(called, set([-2, -3]))
+    
+    def test_time(self):
+        called = []
+        def f(x):
+            called.append(x)
+        ip.push({'f':f})
+        
+        # Test with an expression
+        with tt.AssertPrints("CPU times"):
+            ip.run_line_magic("time", "f(5+9)")
+        self.assertEqual(called, [-14])
+        called[:] = []
+        
+        # Test with a statement (different code path)
+        with tt.AssertPrints("CPU times"):
+            ip.run_line_magic("time", "a = f(-3 + -2)")
+        self.assertEqual(called, [5])
+    
+    def test_macro(self):
+        ip.push({'a':10})
+        # The AST transformation makes this do a+=-1
+        ip.define_macro("amacro", "a+=1\nprint(a)")
+        
+        with tt.AssertPrints("9"):
+            ip.run_cell("amacro")
+        with tt.AssertPrints("8"):
+            ip.run_cell("amacro")
+
+class IntegerWrapper(ast.NodeTransformer):
+    """Wraps all integers in a call to Integer()"""
+    def visit_Num(self, node):
+        if isinstance(node.n, int):
+            return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
+                            args=[node], keywords=[])
+        return node
+
+class TestAstTransform2(unittest.TestCase):
+    def setUp(self):
+        self.intwrapper = IntegerWrapper()
+        ip.ast_transformers.append(self.intwrapper)
+        
+        self.calls = []
+        def Integer(*args):
+            self.calls.append(args)
+            return args
+        ip.push({"Integer": Integer})
+    
+    def tearDown(self):
+        ip.ast_transformers.remove(self.intwrapper)
+        del ip.user_ns['Integer']
+    
+    def test_run_cell(self):
+        ip.run_cell("n = 2")
+        self.assertEqual(self.calls, [(2,)])
+        
+        # This shouldn't throw an error
+        ip.run_cell("o = 2.0")
+        self.assertEqual(ip.user_ns['o'], 2.0)
+    
+    def test_timeit(self):
+        called = set()
+        def f(x):
+            called.add(x)
+        ip.push({'f':f})
+        
+        with tt.AssertPrints("best of "):
+            ip.run_line_magic("timeit", "-n1 f(1)")
+        self.assertEqual(called, set([(1,)]))
+        called.clear()
+        
+        with tt.AssertPrints("best of "):
+            ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
+        self.assertEqual(called, set([(2,), (3,)]))
+
+class ErrorTransformer(ast.NodeTransformer):
+    """Raises ValueError("test") when it visits a number literal."""
+    def visit_Num(self, node):  # NodeTransformer.visit passes the node in
+        raise ValueError("test")
+
+class TestAstTransformError(unittest.TestCase):
+    def test_unregistering(self):
+        err_transformer = ErrorTransformer()
+        ip.ast_transformers.append(err_transformer)
+        
+        with tt.AssertPrints("unregister", channel='stderr'):
+            ip.run_cell("1 + 2")
+        
+        # This should have been removed.
+        nt.assert_not_in(err_transformer, ip.ast_transformers)
 
 def test__IPYTHON__():
     # This shouldn't raise a NameError, that's all
diff --git a/IPython/frontend/terminal/ipapp.py b/IPython/frontend/terminal/ipapp.py
index f6c3f7e..2d2ef09 100755
--- a/IPython/frontend/terminal/ipapp.py
+++ b/IPython/frontend/terminal/ipapp.py
@@ -351,7 +351,7 @@ class TerminalIPythonApp(BaseIPythonApplication, InteractiveShellApp):
         """Replace --pylab='inline' with --pylab='auto'"""
         if new == 'inline':
             warn.warn("'inline' not available as pylab backend, "
-                      "using 'auto' instead.\n")
+                      "using 'auto' instead.")  # newline now added by warn() -- assumes warn is IPython.utils.warn
             self.pylab = 'auto'
 
     def start(self):
diff --git a/IPython/lib/inputhook.py b/IPython/lib/inputhook.py
index b0e9a67..c2dd403 100644
--- a/IPython/lib/inputhook.py
+++ b/IPython/lib/inputhook.py
@@ -105,7 +105,7 @@ class InputHookManager(object):
     
     def __init__(self):
         if ctypes is None:
-            warn("IPython GUI event loop requires ctypes, %gui will not be available\n")
+            warn("IPython GUI event loop requires ctypes, %gui will not be available")  # newline now added by warn() -- assumes IPython.utils.warn.warn
             return
         self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
         self._apps = {}
diff --git a/IPython/testing/iptest.py b/IPython/testing/iptest.py
index bad03d1..65b9e8b 100644
--- a/IPython/testing/iptest.py
+++ b/IPython/testing/iptest.py
@@ -319,7 +319,7 @@ def make_exclude():
             continue
         fullpath = pjoin(parent, exclusion)
         if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
-            warn("Excluding nonexistent file: %r\n" % exclusion)
+            warn("Excluding nonexistent file: %r" % exclusion)  # newline now added by warn() -- assumes IPython.utils.warn.warn
 
     return exclusions
 
diff --git a/IPython/utils/attic.py b/IPython/utils/attic.py
index 18d53b0..52fd799 100644
--- a/IPython/utils/attic.py
+++ b/IPython/utils/attic.py
@@ -130,9 +130,9 @@ def import_fail_info(mod_name,fns=None):
     """Inform load failure for a module."""
 
     if fns == None:
-        warn("Loading of %s failed.\n" % (mod_name,))
+        warn("Loading of %s failed." % (mod_name,))  # newline now added by warn() -- assumes IPython.utils.warn.warn
     else:
-        warn("Loading of %s from %s failed.\n" % (fns,mod_name))
+        warn("Loading of %s from %s failed." % (fns,mod_name))  # same: no trailing "\n" needed
 
 
 class NotGiven: pass
diff --git a/IPython/utils/warn.py b/IPython/utils/warn.py
index 530b70b..693eeb3 100644
--- a/IPython/utils/warn.py
+++ b/IPython/utils/warn.py
@@ -42,7 +42,7 @@ def warn(msg,level=2,exit_val=1):
 
     if level>0:
         header = ['','','WARNING: ','ERROR: ','FATAL ERROR: ']
-        io.stderr.write('%s%s' % (header[level],msg))
+        print(header[level], msg, sep='', file=io.stderr)  # print() terminates the line, so callers omit a trailing "\n"
         if level == 4:
             print('Exiting.\n', file=io.stderr)
             sys.exit(exit_val)