##// END OF EJS Templates
Combined recursive approach with check for already visited objects to avoid infinite recursion....
Niclas -
Show More
@@ -1,642 +1,591 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 import inspect
118 import inspect
119 from importlib import import_module
119 from importlib import import_module
120 from importlib.util import source_from_cache
120 from importlib.util import source_from_cache
121 from imp import reload
121 from imp import reload
122
122
123 #------------------------------------------------------------------------------
123 #------------------------------------------------------------------------------
124 # Autoreload functionality
124 # Autoreload functionality
125 #------------------------------------------------------------------------------
125 #------------------------------------------------------------------------------
126
126
127 class ModuleReloader(object):
127 class ModuleReloader(object):
128 enabled = False
128 enabled = False
129 """Whether this reloader is enabled"""
129 """Whether this reloader is enabled"""
130
130
131 check_all = True
131 check_all = True
132 """Autoreload all modules, not just those listed in 'modules'"""
132 """Autoreload all modules, not just those listed in 'modules'"""
133
133
134 def __init__(self):
134 def __init__(self):
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
136 self.failed = {}
136 self.failed = {}
137 # Modules specially marked as autoreloadable.
137 # Modules specially marked as autoreloadable.
138 self.modules = {}
138 self.modules = {}
139 # Modules specially marked as not autoreloadable.
139 # Modules specially marked as not autoreloadable.
140 self.skip_modules = {}
140 self.skip_modules = {}
141 # (module-name, name) -> weakref, for replacing old code objects
141 # (module-name, name) -> weakref, for replacing old code objects
142 self.old_objects = {}
142 self.old_objects = {}
143 # Module modification timestamps
143 # Module modification timestamps
144 self.modules_mtimes = {}
144 self.modules_mtimes = {}
145
145
146 # Cache module modification times
146 # Cache module modification times
147 self.check(check_all=True, do_reload=False)
147 self.check(check_all=True, do_reload=False)
148
148
149 def mark_module_skipped(self, module_name):
149 def mark_module_skipped(self, module_name):
150 """Skip reloading the named module in the future"""
150 """Skip reloading the named module in the future"""
151 try:
151 try:
152 del self.modules[module_name]
152 del self.modules[module_name]
153 except KeyError:
153 except KeyError:
154 pass
154 pass
155 self.skip_modules[module_name] = True
155 self.skip_modules[module_name] = True
156
156
157 def mark_module_reloadable(self, module_name):
157 def mark_module_reloadable(self, module_name):
158 """Reload the named module in the future (if it is imported)"""
158 """Reload the named module in the future (if it is imported)"""
159 try:
159 try:
160 del self.skip_modules[module_name]
160 del self.skip_modules[module_name]
161 except KeyError:
161 except KeyError:
162 pass
162 pass
163 self.modules[module_name] = True
163 self.modules[module_name] = True
164
164
165 def aimport_module(self, module_name):
165 def aimport_module(self, module_name):
166 """Import a module, and mark it reloadable
166 """Import a module, and mark it reloadable
167
167
168 Returns
168 Returns
169 -------
169 -------
170 top_module : module
170 top_module : module
171 The imported module if it is top-level, or the top-level
171 The imported module if it is top-level, or the top-level
172 top_name : module
172 top_name : module
173 Name of top_module
173 Name of top_module
174
174
175 """
175 """
176 self.mark_module_reloadable(module_name)
176 self.mark_module_reloadable(module_name)
177
177
178 import_module(module_name)
178 import_module(module_name)
179 top_name = module_name.split('.')[0]
179 top_name = module_name.split('.')[0]
180 top_module = sys.modules[top_name]
180 top_module = sys.modules[top_name]
181 return top_module, top_name
181 return top_module, top_name
182
182
183 def filename_and_mtime(self, module):
183 def filename_and_mtime(self, module):
184 if not hasattr(module, '__file__') or module.__file__ is None:
184 if not hasattr(module, '__file__') or module.__file__ is None:
185 return None, None
185 return None, None
186
186
187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
188 # we cannot reload(__main__) or reload(__mp_main__)
188 # we cannot reload(__main__) or reload(__mp_main__)
189 return None, None
189 return None, None
190
190
191 filename = module.__file__
191 filename = module.__file__
192 path, ext = os.path.splitext(filename)
192 path, ext = os.path.splitext(filename)
193
193
194 if ext.lower() == '.py':
194 if ext.lower() == '.py':
195 py_filename = filename
195 py_filename = filename
196 else:
196 else:
197 try:
197 try:
198 py_filename = source_from_cache(filename)
198 py_filename = source_from_cache(filename)
199 except ValueError:
199 except ValueError:
200 return None, None
200 return None, None
201
201
202 try:
202 try:
203 pymtime = os.stat(py_filename).st_mtime
203 pymtime = os.stat(py_filename).st_mtime
204 except OSError:
204 except OSError:
205 return None, None
205 return None, None
206
206
207 return py_filename, pymtime
207 return py_filename, pymtime
208
208
209 def check(self, check_all=False, do_reload=True):
209 def check(self, check_all=False, do_reload=True):
210 """Check whether some modules need to be reloaded."""
210 """Check whether some modules need to be reloaded."""
211
211
212 if not self.enabled and not check_all:
212 if not self.enabled and not check_all:
213 return
213 return
214
214
215 if check_all or self.check_all:
215 if check_all or self.check_all:
216 modules = list(sys.modules.keys())
216 modules = list(sys.modules.keys())
217 else:
217 else:
218 modules = list(self.modules.keys())
218 modules = list(self.modules.keys())
219
219
220 for modname in modules:
220 for modname in modules:
221 m = sys.modules.get(modname, None)
221 m = sys.modules.get(modname, None)
222
222
223 if modname in self.skip_modules:
223 if modname in self.skip_modules:
224 continue
224 continue
225
225
226 py_filename, pymtime = self.filename_and_mtime(m)
226 py_filename, pymtime = self.filename_and_mtime(m)
227 if py_filename is None:
227 if py_filename is None:
228 continue
228 continue
229
229
230 try:
230 try:
231 if pymtime <= self.modules_mtimes[modname]:
231 if pymtime <= self.modules_mtimes[modname]:
232 continue
232 continue
233 except KeyError:
233 except KeyError:
234 self.modules_mtimes[modname] = pymtime
234 self.modules_mtimes[modname] = pymtime
235 continue
235 continue
236 else:
236 else:
237 if self.failed.get(py_filename, None) == pymtime:
237 if self.failed.get(py_filename, None) == pymtime:
238 continue
238 continue
239
239
240 self.modules_mtimes[modname] = pymtime
240 self.modules_mtimes[modname] = pymtime
241
241
242 # If we've reached this point, we should try to reload the module
242 # If we've reached this point, we should try to reload the module
243 if do_reload:
243 if do_reload:
244 try:
244 try:
245 superreload(m, reload, self.old_objects)
245 superreload(m, reload, self.old_objects)
246 if py_filename in self.failed:
246 if py_filename in self.failed:
247 del self.failed[py_filename]
247 del self.failed[py_filename]
248 except:
248 except:
249 print("[autoreload of %s failed: %s]" % (
249 print("[autoreload of %s failed: %s]" % (
250 modname, traceback.format_exc(10)), file=sys.stderr)
250 modname, traceback.format_exc(10)), file=sys.stderr)
251 self.failed[py_filename] = pymtime
251 self.failed[py_filename] = pymtime
252
252
253 #------------------------------------------------------------------------------
253 #------------------------------------------------------------------------------
254 # superreload
254 # superreload
255 #------------------------------------------------------------------------------
255 #------------------------------------------------------------------------------
256
256
257
257
258 func_attrs = ['__code__', '__defaults__', '__doc__',
258 func_attrs = ['__code__', '__defaults__', '__doc__',
259 '__closure__', '__globals__', '__dict__']
259 '__closure__', '__globals__', '__dict__']
260
260
261
261
262 def update_function(old, new):
262 def update_function(old, new):
263 """Upgrade the code object of a function"""
263 """Upgrade the code object of a function"""
264 for name in func_attrs:
264 for name in func_attrs:
265 try:
265 try:
266 setattr(old, name, getattr(new, name))
266 setattr(old, name, getattr(new, name))
267 except (AttributeError, TypeError):
267 except (AttributeError, TypeError):
268 pass
268 pass
269
269
270
270
271 def _find_instances(old_type):
271 def update_instances(old, new, objects=None, visited={}):
272 """Try to find all instances of a class that need updating.
273
274 Classic graph exploration, we want to avoid re-visiting object multiple times.
275 """
276 # find ipython workspace stack frame, this is just to bootstrap where we
277 # find the object that need updating.
278 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
279 if 'trigger' in frame_nfo.function)
280 # build generator for non-private variable values from workspace
281 shell = frame.f_locals['self'].shell
282 user_ns = shell.user_ns
283 user_ns_hidden = shell.user_ns_hidden
284 nonmatching = object()
285 objects = ( value for key, value in user_ns.items()
286 if not key.startswith('_')
287 and (value is not user_ns_hidden.get(key, nonmatching))
288 and not inspect.ismodule(value))
289
290 # note: in the following we do use dict as object might not be hashable.
291 # list of objects we found that will need an update.
292 to_update = {}
293
294 # list of object we have not recursed into yet
295 open_set = {}
296
297 # list of object we have visited already
298 closed_set = {}
299
300 open_set.update({id(o):o for o in objects})
301
302 it = 0
303 while len(open_set) > 0:
304 it += 1
305 if it > 100_000:
306 raise ValueError('infinite')
307 (current_id,current) = next(iter(open_set.items()))
308 if type(current) is old_type:
309 to_update[current_id] = current
310 if hasattr(current, '__dict__') and not (inspect.isfunction(current)
311 or inspect.ismethod(current)):
312 potential_new = {id(o):o for o in current.__dict__.values() if id(o) not in closed_set.keys()}
313 open_set.update(potential_new)
314 # if object is a container, search it
315 if hasattr(current, 'items') or (hasattr(current, '__contains__')
316 and not isinstance(current, str)):
317 potential_new = (value for key, value in current.items()
318 if not str(key).startswith('_')
319 and not inspect.ismodule(value) and not id(value) in closed_set.keys())
320 open_set.update(potential_new)
321 del open_set[id(current)]
322 closed_set[id(current)] = current
323 return to_update.values()
324
325 def update_instances(old, new, objects=None):
326 """Iterate through objects recursively, searching for instances of old and
272 """Iterate through objects recursively, searching for instances of old and
327 replace their __class__ reference with new. If no objects are given, start
273 replace their __class__ reference with new. If no objects are given, start
328 with the current ipython workspace.
274 with the current ipython workspace.
329 """
275 """
330 if not objects:
276 if objects is None:
277 # make sure visited is cleaned when not called recursively
278 visited = {}
331 # find ipython workspace stack frame
279 # find ipython workspace stack frame
332 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
280 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
333 if 'trigger' in frame_nfo.function)
281 if 'trigger' in frame_nfo.function)
334 # build generator for non-private variable values from workspace
282 # build generator for non-private variable values from workspace
335 shell = frame.f_locals['self'].shell
283 shell = frame.f_locals['self'].shell
336 user_ns = shell.user_ns
284 user_ns = shell.user_ns
337 user_ns_hidden = shell.user_ns_hidden
285 user_ns_hidden = shell.user_ns_hidden
338 nonmatching = object()
286 nonmatching = object()
339 objects = ( value for key, value in user_ns.items()
287 objects = ( value for key, value in user_ns.items()
340 if not key.startswith('_')
288 if not key.startswith('_')
341 and (value is not user_ns_hidden.get(key, nonmatching))
289 and (value is not user_ns_hidden.get(key, nonmatching))
342 and not inspect.ismodule(value))
290 and not inspect.ismodule(value))
343
291
344 # use dict values if objects is a dict but don't touch private variables
292 # use dict values if objects is a dict but don't touch private variables
345 if hasattr(objects, 'items'):
293 if hasattr(objects, 'items'):
346 objects = (value for key, value in objects.items()
294 objects = (value for key, value in objects.items()
347 if not str(key).startswith('_')
295 if not str(key).startswith('_')
348 and not inspect.ismodule(value) )
296 and not inspect.ismodule(value) )
349
297
350 # try if objects is iterable
298 # try if objects is iterable
351 try:
299 try:
352 for obj in objects:
300 for obj in (obj for obj in objects if id(obj) not in visited):
301 # add current object to visited to avoid revisiting
302 visited.update({id(obj):obj})
353
303
354 # update, if object is instance of old_class (but no subclasses)
304 # update, if object is instance of old_class (but no subclasses)
355 if type(obj) is old:
305 if type(obj) is old:
356 obj.__class__ = new
306 obj.__class__ = new
357
307
358
308
359 # if object is instance of other class, look for nested instances
309 # if object is instance of other class, look for nested instances
360 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
310 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
361 or inspect.ismethod(obj)):
311 or inspect.ismethod(obj)):
362 update_instances(old, new, obj.__dict__)
312 update_instances(old, new, obj.__dict__, visited)
363
313
364 # if object is a container, search it
314 # if object is a container, search it
365 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
315 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
366 and not isinstance(obj, str)):
316 and not isinstance(obj, str)):
367 update_instances(old, new, obj)
317 update_instances(old, new, obj, visited)
368
318
369 except TypeError:
319 except TypeError:
370 pass
320 pass
371
321
372
322
373 def update_class(old, new):
323 def update_class(old, new):
374 """Replace stuff in the __dict__ of a class, and upgrade
324 """Replace stuff in the __dict__ of a class, and upgrade
375 method code objects, and add new methods, if any"""
325 method code objects, and add new methods, if any"""
376 print('old is', old)
326 print('old is', id(old))
377 for key in list(old.__dict__.keys()):
327 for key in list(old.__dict__.keys()):
378 old_obj = getattr(old, key)
328 old_obj = getattr(old, key)
379 try:
329 try:
380 new_obj = getattr(new, key)
330 new_obj = getattr(new, key)
381 # explicitly checking that comparison returns True to handle
331 # explicitly checking that comparison returns True to handle
382 # cases where `==` doesn't return a boolean.
332 # cases where `==` doesn't return a boolean.
383 if (old_obj == new_obj) is True:
333 if (old_obj == new_obj) is True:
384 continue
334 continue
385 except AttributeError:
335 except AttributeError:
386 # obsolete attribute: remove it
336 # obsolete attribute: remove it
387 try:
337 try:
388 delattr(old, key)
338 delattr(old, key)
389 except (AttributeError, TypeError):
339 except (AttributeError, TypeError):
390 pass
340 pass
391 continue
341 continue
392
342
393 if update_generic(old_obj, new_obj): continue
343 if update_generic(old_obj, new_obj): continue
394
344
395 try:
345 try:
396 setattr(old, key, getattr(new, key))
346 setattr(old, key, getattr(new, key))
397 except (AttributeError, TypeError):
347 except (AttributeError, TypeError):
398 pass # skip non-writable attributes
348 pass # skip non-writable attributes
399
349
400 for key in list(new.__dict__.keys()):
350 for key in list(new.__dict__.keys()):
401 if key not in list(old.__dict__.keys()):
351 if key not in list(old.__dict__.keys()):
402 try:
352 try:
403 setattr(old, key, getattr(new, key))
353 setattr(old, key, getattr(new, key))
404 except (AttributeError, TypeError):
354 except (AttributeError, TypeError):
405 pass # skip non-writable attributes
355 pass # skip non-writable attributes
406
356
407 # update all instances of class
357 # update all instances of class
408 for instance in _find_instances(old):
358 update_instances(old, new)
409 instance.__class__ = new
410
359
411
360
412 def update_property(old, new):
361 def update_property(old, new):
413 """Replace get/set/del functions of a property"""
362 """Replace get/set/del functions of a property"""
414 update_generic(old.fdel, new.fdel)
363 update_generic(old.fdel, new.fdel)
415 update_generic(old.fget, new.fget)
364 update_generic(old.fget, new.fget)
416 update_generic(old.fset, new.fset)
365 update_generic(old.fset, new.fset)
417
366
418
367
419 def isinstance2(a, b, typ):
368 def isinstance2(a, b, typ):
420 return isinstance(a, typ) and isinstance(b, typ)
369 return isinstance(a, typ) and isinstance(b, typ)
421
370
422
371
423 UPDATE_RULES = [
372 UPDATE_RULES = [
424 (lambda a, b: isinstance2(a, b, type),
373 (lambda a, b: isinstance2(a, b, type),
425 update_class),
374 update_class),
426 (lambda a, b: isinstance2(a, b, types.FunctionType),
375 (lambda a, b: isinstance2(a, b, types.FunctionType),
427 update_function),
376 update_function),
428 (lambda a, b: isinstance2(a, b, property),
377 (lambda a, b: isinstance2(a, b, property),
429 update_property),
378 update_property),
430 ]
379 ]
431 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
380 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
432 lambda a, b: update_function(a.__func__, b.__func__)),
381 lambda a, b: update_function(a.__func__, b.__func__)),
433 ])
382 ])
434
383
435
384
436 def update_generic(a, b):
385 def update_generic(a, b):
437 for type_check, update in UPDATE_RULES:
386 for type_check, update in UPDATE_RULES:
438 if type_check(a, b):
387 if type_check(a, b):
439 update(a, b)
388 update(a, b)
440 return True
389 return True
441 return False
390 return False
442
391
443
392
444 class StrongRef(object):
393 class StrongRef(object):
445 def __init__(self, obj):
394 def __init__(self, obj):
446 self.obj = obj
395 self.obj = obj
447 def __call__(self):
396 def __call__(self):
448 return self.obj
397 return self.obj
449
398
450
399
451 def superreload(module, reload=reload, old_objects=None):
400 def superreload(module, reload=reload, old_objects=None):
452 """Enhanced version of the builtin reload function.
401 """Enhanced version of the builtin reload function.
453
402
454 superreload remembers objects previously in the module, and
403 superreload remembers objects previously in the module, and
455
404
456 - upgrades the class dictionary of every old class in the module
405 - upgrades the class dictionary of every old class in the module
457 - upgrades the code object of every old function and method
406 - upgrades the code object of every old function and method
458 - clears the module's namespace before reloading
407 - clears the module's namespace before reloading
459
408
460 """
409 """
461 if old_objects is None:
410 if old_objects is None:
462 old_objects = {}
411 old_objects = {}
463
412
464 # collect old objects in the module
413 # collect old objects in the module
465 for name, obj in list(module.__dict__.items()):
414 for name, obj in list(module.__dict__.items()):
466 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
415 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
467 continue
416 continue
468 key = (module.__name__, name)
417 key = (module.__name__, name)
469 try:
418 try:
470 old_objects.setdefault(key, []).append(weakref.ref(obj))
419 old_objects.setdefault(key, []).append(weakref.ref(obj))
471 except TypeError:
420 except TypeError:
472 pass
421 pass
473
422
474 # reload module
423 # reload module
475 try:
424 try:
476 # clear namespace first from old cruft
425 # clear namespace first from old cruft
477 old_dict = module.__dict__.copy()
426 old_dict = module.__dict__.copy()
478 old_name = module.__name__
427 old_name = module.__name__
479 module.__dict__.clear()
428 module.__dict__.clear()
480 module.__dict__['__name__'] = old_name
429 module.__dict__['__name__'] = old_name
481 module.__dict__['__loader__'] = old_dict['__loader__']
430 module.__dict__['__loader__'] = old_dict['__loader__']
482 except (TypeError, AttributeError, KeyError):
431 except (TypeError, AttributeError, KeyError):
483 pass
432 pass
484
433
485 try:
434 try:
486 module = reload(module)
435 module = reload(module)
487 except:
436 except:
488 # restore module dictionary on failed reload
437 # restore module dictionary on failed reload
489 module.__dict__.update(old_dict)
438 module.__dict__.update(old_dict)
490 raise
439 raise
491
440
492 # iterate over all objects and update functions & classes
441 # iterate over all objects and update functions & classes
493 for name, new_obj in list(module.__dict__.items()):
442 for name, new_obj in list(module.__dict__.items()):
494 key = (module.__name__, name)
443 key = (module.__name__, name)
495 if key not in old_objects: continue
444 if key not in old_objects: continue
496
445
497 new_refs = []
446 new_refs = []
498 for old_ref in old_objects[key]:
447 for old_ref in old_objects[key]:
499 old_obj = old_ref()
448 old_obj = old_ref()
500 if old_obj is None: continue
449 if old_obj is None: continue
501 new_refs.append(old_ref)
450 new_refs.append(old_ref)
502 update_generic(old_obj, new_obj)
451 update_generic(old_obj, new_obj)
503
452
504 if new_refs:
453 if new_refs:
505 old_objects[key] = new_refs
454 old_objects[key] = new_refs
506 else:
455 else:
507 del old_objects[key]
456 del old_objects[key]
508
457
509 return module
458 return module
510
459
511 #------------------------------------------------------------------------------
460 #------------------------------------------------------------------------------
512 # IPython connectivity
461 # IPython connectivity
513 #------------------------------------------------------------------------------
462 #------------------------------------------------------------------------------
514
463
515 from IPython.core.magic import Magics, magics_class, line_magic
464 from IPython.core.magic import Magics, magics_class, line_magic
516
465
517 @magics_class
466 @magics_class
518 class AutoreloadMagics(Magics):
467 class AutoreloadMagics(Magics):
519 def __init__(self, *a, **kw):
468 def __init__(self, *a, **kw):
520 super(AutoreloadMagics, self).__init__(*a, **kw)
469 super(AutoreloadMagics, self).__init__(*a, **kw)
521 self._reloader = ModuleReloader()
470 self._reloader = ModuleReloader()
522 self._reloader.check_all = False
471 self._reloader.check_all = False
523 self.loaded_modules = set(sys.modules)
472 self.loaded_modules = set(sys.modules)
524
473
525 @line_magic
474 @line_magic
526 def autoreload(self, parameter_s=''):
475 def autoreload(self, parameter_s=''):
527 r"""%autoreload => Reload modules automatically
476 r"""%autoreload => Reload modules automatically
528
477
529 %autoreload
478 %autoreload
530 Reload all modules (except those excluded by %aimport) automatically
479 Reload all modules (except those excluded by %aimport) automatically
531 now.
480 now.
532
481
533 %autoreload 0
482 %autoreload 0
534 Disable automatic reloading.
483 Disable automatic reloading.
535
484
536 %autoreload 1
485 %autoreload 1
537 Reload all modules imported with %aimport every time before executing
486 Reload all modules imported with %aimport every time before executing
538 the Python code typed.
487 the Python code typed.
539
488
540 %autoreload 2
489 %autoreload 2
541 Reload all modules (except those excluded by %aimport) every time
490 Reload all modules (except those excluded by %aimport) every time
542 before executing the Python code typed.
491 before executing the Python code typed.
543
492
544 Reloading Python modules in a reliable way is in general
493 Reloading Python modules in a reliable way is in general
545 difficult, and unexpected things may occur. %autoreload tries to
494 difficult, and unexpected things may occur. %autoreload tries to
546 work around common pitfalls by replacing function code objects and
495 work around common pitfalls by replacing function code objects and
547 parts of classes previously in the module with new versions. This
496 parts of classes previously in the module with new versions. This
548 makes the following things to work:
497 makes the following things to work:
549
498
550 - Functions and classes imported via 'from xxx import foo' are upgraded
499 - Functions and classes imported via 'from xxx import foo' are upgraded
551 to new versions when 'xxx' is reloaded.
500 to new versions when 'xxx' is reloaded.
552
501
553 - Methods and properties of classes are upgraded on reload, so that
502 - Methods and properties of classes are upgraded on reload, so that
554 calling 'c.foo()' on an object 'c' created before the reload causes
503 calling 'c.foo()' on an object 'c' created before the reload causes
555 the new code for 'foo' to be executed.
504 the new code for 'foo' to be executed.
556
505
557 Some of the known remaining caveats are:
506 Some of the known remaining caveats are:
558
507
559 - Replacing code objects does not always succeed: changing a @property
508 - Replacing code objects does not always succeed: changing a @property
560 in a class to an ordinary method or a method to a member variable
509 in a class to an ordinary method or a method to a member variable
561 can cause problems (but in old objects only).
510 can cause problems (but in old objects only).
562
511
563 - Functions that are removed (eg. via monkey-patching) from a module
512 - Functions that are removed (eg. via monkey-patching) from a module
564 before it is reloaded are not upgraded.
513 before it is reloaded are not upgraded.
565
514
566 - C extension modules cannot be reloaded, and so cannot be
515 - C extension modules cannot be reloaded, and so cannot be
567 autoreloaded.
516 autoreloaded.
568
517
569 """
518 """
570 if parameter_s == '':
519 if parameter_s == '':
571 self._reloader.check(True)
520 self._reloader.check(True)
572 elif parameter_s == '0':
521 elif parameter_s == '0':
573 self._reloader.enabled = False
522 self._reloader.enabled = False
574 elif parameter_s == '1':
523 elif parameter_s == '1':
575 self._reloader.check_all = False
524 self._reloader.check_all = False
576 self._reloader.enabled = True
525 self._reloader.enabled = True
577 elif parameter_s == '2':
526 elif parameter_s == '2':
578 self._reloader.check_all = True
527 self._reloader.check_all = True
579 self._reloader.enabled = True
528 self._reloader.enabled = True
580
529
581 @line_magic
530 @line_magic
582 def aimport(self, parameter_s='', stream=None):
531 def aimport(self, parameter_s='', stream=None):
583 """%aimport => Import modules for automatic reloading.
532 """%aimport => Import modules for automatic reloading.
584
533
585 %aimport
534 %aimport
586 List modules to automatically import and not to import.
535 List modules to automatically import and not to import.
587
536
588 %aimport foo
537 %aimport foo
589 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
538 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
590
539
591 %aimport foo, bar
540 %aimport foo, bar
592 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
541 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
593
542
594 %aimport -foo
543 %aimport -foo
595 Mark module 'foo' to not be autoreloaded for %autoreload 1
544 Mark module 'foo' to not be autoreloaded for %autoreload 1
596 """
545 """
597 modname = parameter_s
546 modname = parameter_s
598 if not modname:
547 if not modname:
599 to_reload = sorted(self._reloader.modules.keys())
548 to_reload = sorted(self._reloader.modules.keys())
600 to_skip = sorted(self._reloader.skip_modules.keys())
549 to_skip = sorted(self._reloader.skip_modules.keys())
601 if stream is None:
550 if stream is None:
602 stream = sys.stdout
551 stream = sys.stdout
603 if self._reloader.check_all:
552 if self._reloader.check_all:
604 stream.write("Modules to reload:\nall-except-skipped\n")
553 stream.write("Modules to reload:\nall-except-skipped\n")
605 else:
554 else:
606 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
555 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
607 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
556 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
608 elif modname.startswith('-'):
557 elif modname.startswith('-'):
609 modname = modname[1:]
558 modname = modname[1:]
610 self._reloader.mark_module_skipped(modname)
559 self._reloader.mark_module_skipped(modname)
611 else:
560 else:
612 for _module in ([_.strip() for _ in modname.split(',')]):
561 for _module in ([_.strip() for _ in modname.split(',')]):
613 top_module, top_name = self._reloader.aimport_module(_module)
562 top_module, top_name = self._reloader.aimport_module(_module)
614
563
615 # Inject module to user namespace
564 # Inject module to user namespace
616 self.shell.push({top_name: top_module})
565 self.shell.push({top_name: top_module})
617
566
618 def pre_run_cell(self):
567 def pre_run_cell(self):
619 if self._reloader.enabled:
568 if self._reloader.enabled:
620 try:
569 try:
621 self._reloader.check()
570 self._reloader.check()
622 except:
571 except:
623 pass
572 pass
624
573
625 def post_execute_hook(self):
574 def post_execute_hook(self):
626 """Cache the modification times of any modules imported in this execution
575 """Cache the modification times of any modules imported in this execution
627 """
576 """
628 newly_loaded_modules = set(sys.modules) - self.loaded_modules
577 newly_loaded_modules = set(sys.modules) - self.loaded_modules
629 for modname in newly_loaded_modules:
578 for modname in newly_loaded_modules:
630 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
579 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
631 if pymtime is not None:
580 if pymtime is not None:
632 self._reloader.modules_mtimes[modname] = pymtime
581 self._reloader.modules_mtimes[modname] = pymtime
633
582
634 self.loaded_modules.update(newly_loaded_modules)
583 self.loaded_modules.update(newly_loaded_modules)
635
584
636
585
637 def load_ipython_extension(ip):
586 def load_ipython_extension(ip):
638 """Load the extension in IPython."""
587 """Load the extension in IPython."""
639 auto_reload = AutoreloadMagics(ip)
588 auto_reload = AutoreloadMagics(ip)
640 ip.register_magics(auto_reload)
589 ip.register_magics(auto_reload)
641 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
590 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
642 ip.events.register('post_execute', auto_reload.post_execute_hook)
591 ip.events.register('post_execute', auto_reload.post_execute_hook)
General Comments 0
You need to be logged in to leave comments. Login now