From a7cc5832e2bb3b997ba5d096630493c014b0ce9a 2012-11-30 21:51:43 From: Thomas Kluyver Date: 2012-11-30 21:51:43 Subject: [PATCH] Merge pull request #2301 from takluyver/ast-transfomers Ast transfomers --- diff --git a/IPython/core/history.py b/IPython/core/history.py index 0dcf823..558e517 100644 --- a/IPython/core/history.py +++ b/IPython/core/history.py @@ -171,7 +171,7 @@ class HistoryAccessor(Configurable): self.hist_file = self._get_hist_file_name(profile) if sqlite3 is None and self.enabled: - warn("IPython History requires SQLite, your history will not be saved\n") + warn("IPython History requires SQLite, your history will not be saved") self.enabled = False self.init_db() diff --git a/IPython/core/interactiveshell.py b/IPython/core/interactiveshell.py index bbebb1f..133b352 100644 --- a/IPython/core/interactiveshell.py +++ b/IPython/core/interactiveshell.py @@ -196,6 +196,13 @@ class InteractiveShell(SingletonConfigurable): """An enhanced, interactive shell for Python.""" _instance = None + + ast_transformers = List([], config=True, help= + """ + A list of ast.NodeTransformer subclass instances, which will be applied + to user input before code is run. + """ + ) autocall = Enum((0,1,2), default_value=0, config=True, help= """ @@ -326,7 +333,7 @@ class InteractiveShell(SingletonConfigurable): 'prompt_out' : 'out_template', 'prompts_pad_left' : 'justify', } - warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format( + warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}".format( name=name, newname=table[name]) ) # protect against weird cases where self.config may not exist: @@ -709,7 +716,7 @@ class InteractiveShell(SingletonConfigurable): return warn("Attempting to work in a virtualenv. If you encounter problems, please " - "install IPython inside the virtualenv.\n") + "install IPython inside the virtualenv.") if sys.platform == "win32": virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages') else: @@ -2611,6 +2618,8 @@ class InteractiveShell(SingletonConfigurable): self.execution_count += 1 return None + code_ast = self.transform_ast(code_ast) + interactivity = "none" if silent else self.ast_node_interactivity self.run_ast_nodes(code_ast.body, cell_name, interactivity=interactivity) @@ -2643,6 +2652,31 @@ class InteractiveShell(SingletonConfigurable): self.history_manager.store_output(self.execution_count) # Each cell is a *single* input, regardless of how many lines it has self.execution_count += 1 + + def transform_ast(self, node): + """Apply the AST transformations from self.ast_transformers + + Parameters + ---------- + node : ast.Node + The root node to be transformed. Typically called with the ast.Module + produced by parsing user input. + + Returns + ------- + An ast.Node corresponding to the node it was called with. Note that it + may also modify the passed object, so don't rely on references to the + original AST. + """ + for transformer in self.ast_transformers: + try: + node = transformer.visit(node) + except Exception: + warn("AST transformer %r threw an error. It will be unregistered." % transformer) + self.ast_transformers.remove(transformer) + + return ast.fix_missing_locations(node) + def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'): """Run a sequence of AST nodes. The execution mode depends on the diff --git a/IPython/core/magics/execution.py b/IPython/core/magics/execution.py index 1b916b7..86afccd 100644 --- a/IPython/core/magics/execution.py +++ b/IPython/core/magics/execution.py @@ -14,6 +14,7 @@ # Stdlib import __builtin__ as builtin_mod +import ast import bdb import os import sys @@ -781,26 +782,54 @@ python-profiler package from non-free.""") # but is there a better way to achieve that the code stmt has access # to the shell namespace? transform = self.shell.input_splitter.transform_cell + if cell is None: # called as line magic - setup = 'pass' - stmt = timeit.reindent(transform(stmt), 8) - else: - setup = timeit.reindent(transform(stmt), 4) - stmt = timeit.reindent(transform(cell), 8) - - # From Python 3.3, this template uses new-style string formatting. - if sys.version_info >= (3, 3): - src = timeit.template.format(stmt=stmt, setup=setup) + ast_setup = ast.parse("pass") + ast_stmt = ast.parse(transform(stmt)) else: - src = timeit.template % dict(stmt=stmt, setup=setup) + ast_setup = ast.parse(transform(stmt)) + ast_stmt = ast.parse(transform(cell)) + + ast_setup = self.shell.transform_ast(ast_setup) + ast_stmt = self.shell.transform_ast(ast_stmt) + + # This codestring is taken from timeit.template - we fill it in as an + # AST, so that we can apply our AST transformations to the user code + # without affecting the timing code. + timeit_ast_template = ast.parse('def inner(_it, _timer):\n' + ' setup\n' + ' _t0 = _timer()\n' + ' for _i in _it:\n' + ' stmt\n' + ' _t1 = _timer()\n' + ' return _t1 - _t0\n') + + class TimeitTemplateFiller(ast.NodeTransformer): + "This is quite tightly tied to the template definition above." + def visit_FunctionDef(self, node): + "Fill in the setup statement" + self.generic_visit(node) + if node.name == "inner": + node.body[:1] = ast_setup.body + + return node + + def visit_For(self, node): + "Fill in the statement to be timed" + if getattr(getattr(node.body[0], 'value', None), 'id', None) == 'stmt': + node.body = ast_stmt.body + return node + + timeit_ast = TimeitTemplateFiller().visit(timeit_ast_template) + timeit_ast = ast.fix_missing_locations(timeit_ast) # Track compilation time so it can be reported if too long # Minimum time above which compilation time will be reported tc_min = 0.1 t0 = clock() - code = compile(src, "", "exec") + code = compile(timeit_ast, "", "exec") tc = clock()-t0 ns = {} @@ -884,20 +913,31 @@ python-profiler package from non-free.""") # fail immediately if the given expression can't be compiled expr = self.shell.prefilter(parameter_s,False) + + # Minimum time above which parse time will be reported + tp_min = 0.1 + + t0 = clock() + expr_ast = ast.parse(expr) + tp = clock()-t0 + + # Apply AST transformations + expr_ast = self.shell.transform_ast(expr_ast) # Minimum time above which compilation time will be reported tc_min = 0.1 - try: + if len(expr_ast.body)==1 and isinstance(expr_ast.body[0], ast.Expr): mode = 'eval' - t0 = clock() - code = compile(expr,'',mode) - tc = clock()-t0 - except SyntaxError: + source = '' + expr_ast = ast.Expression(expr_ast.body[0].value) + else: mode = 'exec' - t0 = clock() - code = compile(expr,'',mode) - tc = clock()-t0 + source = '' + t0 = clock() + code = compile(expr_ast, source, mode) + tc = clock()-t0 + # skew measurement as little as possible glob = self.shell.user_ns wtime = time.time @@ -923,6 +963,8 @@ python-profiler package from non-free.""") print "Wall time: %.2f s" % wall_time if tc > tc_min: print "Compiler : %.2f s" % tc + if tp > tp_min: + print "Parser : %.2f s" % tp return out @skip_doctest diff --git a/IPython/core/tests/test_interactiveshell.py b/IPython/core/tests/test_interactiveshell.py index 634d503..cf7f763 100644 --- a/IPython/core/tests/test_interactiveshell.py +++ b/IPython/core/tests/test_interactiveshell.py @@ -20,6 +20,7 @@ Authors # Imports #----------------------------------------------------------------------------- # stdlib +import ast import os import shutil import sys @@ -426,6 +427,132 @@ class TestModules(unittest.TestCase, tt.TempFileMixin): out = "False\nFalse\nFalse\n" tt.ipexec_validate(self.fname, out) +class Negator(ast.NodeTransformer): + """Negates all number literals in an AST.""" + def visit_Num(self, node): + node.n = -node.n + return node + +class TestAstTransform(unittest.TestCase): + def setUp(self): + self.negator = Negator() + ip.ast_transformers.append(self.negator) + + def tearDown(self): + ip.ast_transformers.remove(self.negator) + + def test_run_cell(self): + with tt.AssertPrints('-34'): + ip.run_cell('print (12 + 22)') + + # A named reference to a number shouldn't be transformed. + ip.user_ns['n'] = 55 + with tt.AssertNotPrints('-55'): + ip.run_cell('print (n)') + + def test_timeit(self): + called = set() + def f(x): + called.add(x) + ip.push({'f':f}) + + with tt.AssertPrints("best of "): + ip.run_line_magic("timeit", "-n1 f(1)") + self.assertEqual(called, set([-1])) + called.clear() + + with tt.AssertPrints("best of "): + ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)") + self.assertEqual(called, set([-2, -3])) + + def test_time(self): + called = [] + def f(x): + called.append(x) + ip.push({'f':f}) + + # Test with an expression + with tt.AssertPrints("CPU times"): + ip.run_line_magic("time", "f(5+9)") + self.assertEqual(called, [-14]) + called[:] = [] + + # Test with a statement (different code path) + with tt.AssertPrints("CPU times"): + ip.run_line_magic("time", "a = f(-3 + -2)") + self.assertEqual(called, [5]) + + def test_macro(self): + ip.push({'a':10}) + # The AST transformation makes this do a+=-1 + ip.define_macro("amacro", "a+=1\nprint(a)") + + with tt.AssertPrints("9"): + ip.run_cell("amacro") + with tt.AssertPrints("8"): + ip.run_cell("amacro") + +class IntegerWrapper(ast.NodeTransformer): + """Wraps all integers in a call to Integer()""" + def visit_Num(self, node): + if isinstance(node.n, int): + return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()), + args=[node], keywords=[]) + return node + +class TestAstTransform2(unittest.TestCase): + def setUp(self): + self.intwrapper = IntegerWrapper() + ip.ast_transformers.append(self.intwrapper) + + self.calls = [] + def Integer(*args): + self.calls.append(args) + return args + ip.push({"Integer": Integer}) + + def tearDown(self): + ip.ast_transformers.remove(self.intwrapper) + del ip.user_ns['Integer'] + + def test_run_cell(self): + ip.run_cell("n = 2") + self.assertEqual(self.calls, [(2,)]) + + # This shouldn't throw an error + ip.run_cell("o = 2.0") + self.assertEqual(ip.user_ns['o'], 2.0) + + def test_timeit(self): + called = set() + def f(x): + called.add(x) + ip.push({'f':f}) + + with tt.AssertPrints("best of "): + ip.run_line_magic("timeit", "-n1 f(1)") + self.assertEqual(called, set([(1,)])) + called.clear() + + with tt.AssertPrints("best of "): + ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)") + self.assertEqual(called, set([(2,), (3,)])) + +class ErrorTransformer(ast.NodeTransformer): + """Throws an error when it sees a number.""" + def visit_Num(self): + raise ValueError("test") + +class TestAstTransformError(unittest.TestCase): + def test_unregistering(self): + err_transformer = ErrorTransformer() + ip.ast_transformers.append(err_transformer) + + with tt.AssertPrints("unregister", channel='stderr'): + ip.run_cell("1 + 2") + + # This should have been removed. + nt.assert_not_in(err_transformer, ip.ast_transformers) def test__IPYTHON__(): # This shouldn't raise a NameError, that's all diff --git a/IPython/frontend/terminal/ipapp.py b/IPython/frontend/terminal/ipapp.py index f6c3f7e..2d2ef09 100755 --- a/IPython/frontend/terminal/ipapp.py +++ b/IPython/frontend/terminal/ipapp.py @@ -351,7 +351,7 @@ class TerminalIPythonApp(BaseIPythonApplication, InteractiveShellApp): """Replace --pylab='inline' with --pylab='auto'""" if new == 'inline': warn.warn("'inline' not available as pylab backend, " - "using 'auto' instead.\n") + "using 'auto' instead.") self.pylab = 'auto' def start(self): diff --git a/IPython/lib/inputhook.py b/IPython/lib/inputhook.py index b0e9a67..c2dd403 100644 --- a/IPython/lib/inputhook.py +++ b/IPython/lib/inputhook.py @@ -105,7 +105,7 @@ class InputHookManager(object): def __init__(self): if ctypes is None: - warn("IPython GUI event loop requires ctypes, %gui will not be available\n") + warn("IPython GUI event loop requires ctypes, %gui will not be available") return self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int) self._apps = {} diff --git a/IPython/testing/iptest.py b/IPython/testing/iptest.py index bad03d1..65b9e8b 100644 --- a/IPython/testing/iptest.py +++ b/IPython/testing/iptest.py @@ -319,7 +319,7 @@ def make_exclude(): continue fullpath = pjoin(parent, exclusion) if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'): - warn("Excluding nonexistent file: %r\n" % exclusion) + warn("Excluding nonexistent file: %r" % exclusion) return exclusions diff --git a/IPython/utils/attic.py b/IPython/utils/attic.py index 18d53b0..52fd799 100644 --- a/IPython/utils/attic.py +++ b/IPython/utils/attic.py @@ -130,9 +130,9 @@ def import_fail_info(mod_name,fns=None): """Inform load failure for a module.""" if fns == None: - warn("Loading of %s failed.\n" % (mod_name,)) + warn("Loading of %s failed." % (mod_name,)) else: - warn("Loading of %s from %s failed.\n" % (fns,mod_name)) + warn("Loading of %s from %s failed." % (fns,mod_name)) class NotGiven: pass diff --git a/IPython/utils/warn.py b/IPython/utils/warn.py index 530b70b..693eeb3 100644 --- a/IPython/utils/warn.py +++ b/IPython/utils/warn.py @@ -42,7 +42,7 @@ def warn(msg,level=2,exit_val=1): if level>0: header = ['','','WARNING: ','ERROR: ','FATAL ERROR: '] - io.stderr.write('%s%s' % (header[level],msg)) + print(header[level], msg, sep='', file=io.stderr) if level == 4: print('Exiting.\n', file=io.stderr) sys.exit(exit_val)