##// END OF EJS Templates
Provide and easier way to generate magics and pre-post hooks...
Matthias Bussonnier -
Show More
@@ -0,0 +1,320 b''
1 """
2 This module contains utility function and classes to inject simple ast
3 transformations based on code strings into IPython. While it is already possible
4 with ast-transformers it is not easy to directly manipulate ast.
5
6
7 IPython has pre-code and post-code hooks, but are ran from within the IPython
8 machinery so may be inappropriate, for example for performance mesurement.
9
10 This module give you tools to simplify this, and expose 2 classes:
11
12 - `ReplaceCodeTransformer` which is a simple ast transformer based on code
13 template,
14
15 and for advance case:
16
17 - `Mangler` which is a simple ast transformer that mangle names in the ast.
18
19
20 Example, let's try to make a simple version of the ``timeit`` magic, that run a
21 code snippet 10 times and print the average time taken.
22
23 Basically we want to run :
24
25 .. code-block:: python
26
27 from time import perf_counter
28 now = perf_counter()
29 for i in range(10):
30 __code__ # our code
31 print(f"Time taken: {(perf_counter() - now)/10}")
32 __ret__ # the result of the last statement
33
34 Where ``__code__`` is the code snippet we want to run, and ``__ret__`` is the
35 result, so that if we for example run `dataframe.head()` IPython still display
36 the head of dataframe instead of nothing.
37
38 Here is a complete example of a file `timit2.py` that define such a magic:
39
40 .. code-block:: python
41
42 from IPython.core.magic import (
43 Magics,
44 magics_class,
45 line_cell_magic,
46 )
47 from IPython.core.magics.ast_mod import ReplaceCodeTransformer
48 from textwrap import dedent
49 import ast
50
51 template = template = dedent('''
52 from time import perf_counter
53 now = perf_counter()
54 for i in range(10):
55 __code__
56 print(f"Time taken: {(perf_counter() - now)/10}")
57 __ret__
58 '''
59 )
60
61
62 @magics_class
63 class AstM(Magics):
64 @line_cell_magic
65 def t2(self, line, cell):
66 transformer = ReplaceCodeTransformer.from_string(template)
67 transformer.debug = True
68 transformer.mangler.debug = True
69 new_code = transformer.visit(ast.parse(cell))
70 return exec(compile(new_code, "<ast>", "exec"))
71
72
73 def load_ipython_extension(ip):
74 ip.register_magics(AstM)
75
76
77
78 .. code-block:: python
79
80 In [1]: %load_ext timit2
81
82 In [2]: %%t2
83 ...: import time
84 ...: time.sleep(0.05)
85 ...:
86 ...:
87 Time taken: 0.05435649999999441
88
89
90 If you wish to ran all the code enter in IPython in an ast transformer, you can
91 do so as well:
92
93 .. code-block:: python
94
95 In [1]: from IPython.core.magics.ast_mod import ReplaceCodeTransformer
96 ...:
97 ...: template = '''
98 ...: from time import perf_counter
99 ...: now = perf_counter()
100 ...: __code__
101 ...: print(f"Code ran in {perf_counter()-now}")
102 ...: __ret__'''
103 ...:
104 ...: get_ipython().ast_transformers.append(ReplaceCodeTransformer.from_string(template))
105
106 In [2]: 1+1
107 Code ran in 3.40410006174352e-05
108 Out[2]: 2
109
110
111
112 Hygiene and Mangling
113 --------------------
114
115 The ast transformer above is not hygienic, it may not work if the user code use
116 the same variable names as the ones used in the template. For example.
117
118 To help with this by default the `ReplaceCodeTransformer` will mangle all names
119 staring with 3 underscores. This is a simple heuristic that should work in most
120 case, but can be cumbersome in some case. We provide a `Mangler` class that can
121 be overridden to change the mangling heuristic, or simply use the `mangle_all`
122 utility function. It will _try_ to mangle all names (except `__ret__` and
123 `__code__`), but this include builtins (``print``, ``range``, ``type``) and
124 replace those by invalid identifiers py prepending ``mangle-``:
125 ``mangle-print``, ``mangle-range``, ``mangle-type`` etc. This is not a problem
126 as currently Python AST support invalid identifiers, but it may not be the case
127 in the future.
128
129 You can set `ReplaceCodeTransformer.debug=True` and
130 `ReplaceCodeTransformer.mangler.debug=True` to see the code after mangling and
131 transforming:
132
133 .. code-block:: python
134
135
136 In [1]: from IPython.core.magics.ast_mod import ReplaceCodeTransformer, mangle_all
137 ...:
138 ...: template = '''
139 ...: from builtins import type, print
140 ...: from time import perf_counter
141 ...: now = perf_counter()
142 ...: __code__
143 ...: print(f"Code ran in {perf_counter()-now}")
144 ...: __ret__'''
145 ...:
146 ...: transformer = ReplaceCodeTransformer.from_string(template, mangling_predicate=mangle_all)
147
148
149 In [2]: transformer.debug = True
150 ...: transformer.mangler.debug = True
151 ...: get_ipython().ast_transformers.append(transformer)
152
153 In [3]: 1+1
154 Mangling Alias mangle-type
155 Mangling Alias mangle-print
156 Mangling Alias mangle-perf_counter
157 Mangling now
158 Mangling perf_counter
159 Not mangling __code__
160 Mangling print
161 Mangling perf_counter
162 Mangling now
163 Not mangling __ret__
164 ---- Transformed code ----
165 from builtins import type as mangle-type, print as mangle-print
166 from time import perf_counter as mangle-perf_counter
167 mangle-now = mangle-perf_counter()
168 ret-tmp = 1 + 1
169 mangle-print(f'Code ran in {mangle-perf_counter() - mangle-now}')
170 ret-tmp
171 ---- ---------------- ----
172 Code ran in 0.00013654199938173406
173 Out[3]: 2
174
175
176 """
177
178 __skip_doctest__ = True
179
180
181 from ast import NodeTransformer, Store, Load, Name, Expr, Assign, Module
182 import ast
183 import copy
184
185 from typing import Dict, Optional
186
187
188 mangle_all = lambda name: False if name in ("__ret__", "__code__") else True
189
190
191 class Mangler(NodeTransformer):
192 """
193 Mangle given names in and ast tree to make sure they do not conflict with
194 user code.
195 """
196
197 enabled: bool = True
198 debug: bool = False
199
200 def log(self, *args, **kwargs):
201 if self.debug:
202 print(*args, **kwargs)
203
204 def __init__(self, predicate=None):
205 if predicate is None:
206 predicate = lambda name: name.startswith("___")
207 self.predicate = predicate
208
209 def visit_Name(self, node):
210 if self.predicate(node.id):
211 self.log("Mangling", node.id)
212 # Once in the ast we do not need
213 # names to be valid identifiers.
214 node.id = "mangle-" + node.id
215 else:
216 self.log("Not mangling", node.id)
217 return node
218
219 def visit_FunctionDef(self, node):
220 if self.predicate(node.name):
221 self.log("Mangling", node.name)
222 node.name = "mangle-" + node.name
223 else:
224 self.log("Not mangling", node.name)
225
226 for arg in node.args.args:
227 if self.predicate(arg.arg):
228 self.log("Mangling function arg", arg.arg)
229 arg.arg = "mangle-" + arg.arg
230 else:
231 self.log("Not mangling function arg", arg.arg)
232 return self.generic_visit(node)
233
234 def visit_ImportFrom(self, node):
235 return self._visit_Import_and_ImportFrom(node)
236
237 def visit_Import(self, node):
238 return self._visit_Import_and_ImportFrom(node)
239
240 def _visit_Import_and_ImportFrom(self, node):
241 for alias in node.names:
242 asname = alias.name if alias.asname is None else alias.asname
243 if self.predicate(asname):
244 new_name: str = "mangle-" + asname
245 self.log("Mangling Alias", new_name)
246 alias.asname = new_name
247 else:
248 self.log("Not mangling Alias", alias.asname)
249 return node
250
251
252 class ReplaceCodeTransformer(NodeTransformer):
253 enabled: bool = True
254 debug: bool = False
255 mangler: Mangler
256
257 def __init__(
258 self, template: Module, mapping: Optional[Dict] = None, mangling_predicate=None
259 ):
260 assert isinstance(mapping, (dict, type(None)))
261 assert isinstance(mangling_predicate, (type(None), type(lambda: None)))
262 assert isinstance(template, ast.Module)
263 self.template = template
264 self.mangler = Mangler(predicate=mangling_predicate)
265 if mapping is None:
266 mapping = {}
267 self.mapping = mapping
268
269 @classmethod
270 def from_string(
271 cls, template: str, mapping: Optional[Dict] = None, mangling_predicate=None
272 ):
273 return cls(
274 ast.parse(template), mapping=mapping, mangling_predicate=mangling_predicate
275 )
276
277 def visit_Module(self, code):
278 if not self.enabled:
279 return code
280 # if not isinstance(code, ast.Module):
281 # recursively called...
282 # return generic_visit(self, code)
283 last = code.body[-1]
284 if isinstance(last, Expr):
285 code.body.pop()
286 code.body.append(Assign([Name("ret-tmp", ctx=Store())], value=last.value))
287 ast.fix_missing_locations(code)
288 ret = Expr(value=Name("ret-tmp", ctx=Load()))
289 ret = ast.fix_missing_locations(ret)
290 self.mapping["__ret__"] = ret
291 else:
292 self.mapping["__ret__"] = ast.parse("None").body[0]
293 self.mapping["__code__"] = code.body
294 tpl = ast.fix_missing_locations(self.template)
295
296 tx = copy.deepcopy(tpl)
297 tx = self.mangler.visit(tx)
298 node = self.generic_visit(tx)
299 node_2 = ast.fix_missing_locations(node)
300 if self.debug:
301 print("---- Transformed code ----")
302 print(ast.unparse(node_2))
303 print("---- ---------------- ----")
304 return node_2
305
306 # this does not work as the name might be in a list and one might want to extend the list.
307 # def visit_Name(self, name):
308 # if name.id in self.mapping and name.id == "__ret__":
309 # print(name, "in mapping")
310 # if isinstance(name.ctx, ast.Store):
311 # return Name("tmp", ctx=Store())
312 # else:
313 # return copy.deepcopy(self.mapping[name.id])
314 # return name
315
316 def visit_Expr(self, expr):
317 if isinstance(expr.value, Name) and expr.value.id in self.mapping:
318 if self.mapping[expr.value.id] is not None:
319 return copy.deepcopy(self.mapping[expr.value.id])
320 return self.generic_visit(expr)
@@ -3338,8 +3338,11 b' class InteractiveShell(SingletonConfigurable):'
3338 # an InputRejected. Short-circuit in this case so that we
3338 # an InputRejected. Short-circuit in this case so that we
3339 # don't unregister the transform.
3339 # don't unregister the transform.
3340 raise
3340 raise
3341 except Exception:
3341 except Exception as e:
3342 warn("AST transformer %r threw an error. It will be unregistered." % transformer)
3342 warn(
3343 "AST transformer %r threw an error. It will be unregistered. %s"
3344 % (transformer, e)
3345 )
3343 self.ast_transformers.remove(transformer)
3346 self.ast_transformers.remove(transformer)
3344
3347
3345 if self.ast_transformers:
3348 if self.ast_transformers:
@@ -8,6 +8,7 b''
8 import ast
8 import ast
9 import bdb
9 import bdb
10 import builtins as builtin_mod
10 import builtins as builtin_mod
11 import copy
11 import cProfile as profile
12 import cProfile as profile
12 import gc
13 import gc
13 import itertools
14 import itertools
@@ -19,14 +20,28 b' import shlex'
19 import sys
20 import sys
20 import time
21 import time
21 import timeit
22 import timeit
22 from ast import Module
23 from typing import Dict, Any
24 from ast import (
25 Assign,
26 Call,
27 Expr,
28 Load,
29 Module,
30 Name,
31 NodeTransformer,
32 Store,
33 parse,
34 unparse,
35 )
23 from io import StringIO
36 from io import StringIO
24 from logging import error
37 from logging import error
25 from pathlib import Path
38 from pathlib import Path
26 from pdb import Restart
39 from pdb import Restart
40 from textwrap import dedent, indent
27 from warnings import warn
41 from warnings import warn
28
42
29 from IPython.core import magic_arguments, oinspect, page
43 from IPython.core import magic_arguments, oinspect, page
44 from IPython.core.displayhook import DisplayHook
30 from IPython.core.error import UsageError
45 from IPython.core.error import UsageError
31 from IPython.core.macro import Macro
46 from IPython.core.macro import Macro
32 from IPython.core.magic import (
47 from IPython.core.magic import (
@@ -37,8 +52,8 b' from IPython.core.magic import ('
37 magics_class,
52 magics_class,
38 needs_local_scope,
53 needs_local_scope,
39 no_var_expand,
54 no_var_expand,
40 output_can_be_silenced,
41 on_off,
55 on_off,
56 output_can_be_silenced,
42 )
57 )
43 from IPython.testing.skipdoctest import skip_doctest
58 from IPython.testing.skipdoctest import skip_doctest
44 from IPython.utils.capture import capture_output
59 from IPython.utils.capture import capture_output
@@ -47,7 +62,7 b' from IPython.utils.ipstruct import Struct'
47 from IPython.utils.module_paths import find_mod
62 from IPython.utils.module_paths import find_mod
48 from IPython.utils.path import get_py_filename, shellglob
63 from IPython.utils.path import get_py_filename, shellglob
49 from IPython.utils.timing import clock, clock2
64 from IPython.utils.timing import clock, clock2
50 from IPython.core.displayhook import DisplayHook
65 from IPython.core.magics.ast_mod import ReplaceCodeTransformer
51
66
52 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
53 # Magic implementation classes
68 # Magic implementation classes
@@ -164,9 +179,9 b' class Timer(timeit.Timer):'
164
179
165 @magics_class
180 @magics_class
166 class ExecutionMagics(Magics):
181 class ExecutionMagics(Magics):
167 """Magics related to code execution, debugging, profiling, etc.
182 """Magics related to code execution, debugging, profiling, etc."""
168
183
169 """
184 _transformers: Dict[str, Any] = {}
170
185
171 def __init__(self, shell):
186 def __init__(self, shell):
172 super(ExecutionMagics, self).__init__(shell)
187 super(ExecutionMagics, self).__init__(shell)
@@ -1474,6 +1489,83 b' class ExecutionMagics(Magics):'
1474 elif args.output:
1489 elif args.output:
1475 self.shell.user_ns[args.output] = io
1490 self.shell.user_ns[args.output] = io
1476
1491
1492 @skip_doctest
1493 @magic_arguments.magic_arguments()
1494 @magic_arguments.argument("name", type=str, default="default", nargs="?")
1495 @magic_arguments.argument(
1496 "--remove", action="store_true", help="remove the current transformer"
1497 )
1498 @magic_arguments.argument(
1499 "--list", action="store_true", help="list existing transformers name"
1500 )
1501 @magic_arguments.argument(
1502 "--list-all",
1503 action="store_true",
1504 help="list existing transformers name and code template",
1505 )
1506 @line_cell_magic
1507 def code_wrap(self, line, cell=None):
1508 """
1509 Simple magic to quickly define a code transformer for all IPython's future imput.
1510
1511 ``__code__`` and ``__ret__`` are special variable that represent the code to run
1512 and the value of the last expression of ``__code__`` respectively.
1513
1514 Examples
1515 --------
1516
1517 .. ipython::
1518
1519 In [1]: %%code_wrap before_after
1520 ...: print('before')
1521 ...: __code__
1522 ...: print('after')
1523 ...: __ret__
1524
1525
1526 In [2]: 1
1527 before
1528 after
1529 Out[2]: 1
1530
1531 In [3]: %code_wrap --list
1532 before_after
1533
1534 In [4]: %code_wrap --list-all
1535 before_after :
1536 print('before')
1537 __code__
1538 print('after')
1539 __ret__
1540
1541 In [5]: %code_wrap --remove before_after
1542
1543 """
1544 args = magic_arguments.parse_argstring(self.code_wrap, line)
1545
1546 if args.list:
1547 for name in self._transformers.keys():
1548 print(name)
1549 return
1550 if args.list_all:
1551 for name, _t in self._transformers.items():
1552 print(name, ":")
1553 print(indent(ast.unparse(_t.template), " "))
1554 print()
1555 return
1556
1557 to_remove = self._transformers.pop(args.name, None)
1558 if to_remove in self.shell.ast_transformers:
1559 self.shell.ast_transformers.remove(to_remove)
1560 if cell is None or args.remove:
1561 return
1562
1563 _trs = ReplaceCodeTransformer(ast.parse(cell))
1564
1565 self._transformers[args.name] = _trs
1566 self.shell.ast_transformers.append(_trs)
1567
1568
1477 def parse_breakpoint(text, current_file):
1569 def parse_breakpoint(text, current_file):
1478 '''Returns (file, line) for file:line and (current_file, line) for line'''
1570 '''Returns (file, line) for file:line and (current_file, line) for line'''
1479 colon = text.find(':')
1571 colon = text.find(':')
@@ -1519,4 +1611,4 b' def _format_time(timespan, precision=3):'
1519 order = min(-int(math.floor(math.log10(timespan)) // 3), 3)
1611 order = min(-int(math.floor(math.log10(timespan)) // 3), 3)
1520 else:
1612 else:
1521 order = 3
1613 order = 3
1522 return u"%.*g %s" % (precision, timespan * scaling[order], units[order])
1614 return "%.*g %s" % (precision, timespan * scaling[order], units[order])
@@ -24,8 +24,6 b" if __name__ == '__main__':"
24 docwriter.package_skip_patterns += [r'\.external$',
24 docwriter.package_skip_patterns += [r'\.external$',
25 # Extensions are documented elsewhere.
25 # Extensions are documented elsewhere.
26 r'\.extensions',
26 r'\.extensions',
27 # Magics are documented separately
28 r'\.core\.magics',
29 # This isn't API
27 # This isn't API
30 r'\.sphinxext',
28 r'\.sphinxext',
31 # Shims
29 # Shims
General Comments 0
You need to be logged in to leave comments. Login now