##// END OF EJS Templates
BUG: extensions/autoreload: don't clobber module dictionary if reload fails
Pauli Virtanen -
Show More
@@ -1,393 +1,400 b''
1 """
1 """
2 IPython extension: autoreload modules before executing the next line
2 IPython extension: autoreload modules before executing the next line
3
3
4 Try::
4 Try::
5
5
6 %autoreload?
6 %autoreload?
7
7
8 for documentation.
8 for documentation.
9 """
9 """
10
10
11 # Pauli Virtanen <pav@iki.fi>, 2008.
11 # Pauli Virtanen <pav@iki.fi>, 2008.
12 # Thomas Heller, 2000.
12 # Thomas Heller, 2000.
13 #
13 #
14 # This IPython module is written by Pauli Virtanen, based on the autoreload
14 # This IPython module is written by Pauli Virtanen, based on the autoreload
15 # code by Thomas Heller.
15 # code by Thomas Heller.
16
16
17 #------------------------------------------------------------------------------
17 #------------------------------------------------------------------------------
18 # Autoreload functionality
18 # Autoreload functionality
19 #------------------------------------------------------------------------------
19 #------------------------------------------------------------------------------
20
20
21 import time, os, threading, sys, types, imp, inspect, traceback, atexit
21 import time, os, threading, sys, types, imp, inspect, traceback, atexit
22 import weakref
22 import weakref
23
23
24 def _get_compiled_ext():
24 def _get_compiled_ext():
25 """Official way to get the extension of compiled files (.pyc or .pyo)"""
25 """Official way to get the extension of compiled files (.pyc or .pyo)"""
26 for ext, mode, typ in imp.get_suffixes():
26 for ext, mode, typ in imp.get_suffixes():
27 if typ == imp.PY_COMPILED:
27 if typ == imp.PY_COMPILED:
28 return ext
28 return ext
29
29
30 PY_COMPILED_EXT = _get_compiled_ext()
30 PY_COMPILED_EXT = _get_compiled_ext()
31
31
32 class ModuleReloader(object):
32 class ModuleReloader(object):
33 enabled = False
33 enabled = False
34 """Whether this reloader is enabled"""
34 """Whether this reloader is enabled"""
35
35
36 failed = {}
36 failed = {}
37 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
37 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
38
38
39 modules = {}
39 modules = {}
40 """Modules specially marked as autoreloadable."""
40 """Modules specially marked as autoreloadable."""
41
41
42 skip_modules = {}
42 skip_modules = {}
43 """Modules specially marked as not autoreloadable."""
43 """Modules specially marked as not autoreloadable."""
44
44
45 check_all = True
45 check_all = True
46 """Autoreload all modules, not just those listed in 'modules'"""
46 """Autoreload all modules, not just those listed in 'modules'"""
47
47
48 old_objects = {}
48 old_objects = {}
49 """(module-name, name) -> weakref, for replacing old code objects"""
49 """(module-name, name) -> weakref, for replacing old code objects"""
50
50
51 def mark_module_skipped(self, module_name):
51 def mark_module_skipped(self, module_name):
52 """Skip reloading the named module in the future"""
52 """Skip reloading the named module in the future"""
53 try:
53 try:
54 del self.modules[module_name]
54 del self.modules[module_name]
55 except KeyError:
55 except KeyError:
56 pass
56 pass
57 self.skip_modules[module_name] = True
57 self.skip_modules[module_name] = True
58
58
59 def mark_module_reloadable(self, module_name):
59 def mark_module_reloadable(self, module_name):
60 """Reload the named module in the future (if it is imported)"""
60 """Reload the named module in the future (if it is imported)"""
61 try:
61 try:
62 del self.skip_modules[module_name]
62 del self.skip_modules[module_name]
63 except KeyError:
63 except KeyError:
64 pass
64 pass
65 self.modules[module_name] = True
65 self.modules[module_name] = True
66
66
67 def aimport_module(self, module_name):
67 def aimport_module(self, module_name):
68 """Import a module, and mark it reloadable
68 """Import a module, and mark it reloadable
69
69
70 Returns
70 Returns
71 -------
71 -------
72 top_module : module
72 top_module : module
73 The imported module if it is top-level, or the top-level
73 The imported module if it is top-level, or the top-level
74 top_name : module
74 top_name : module
75 Name of top_module
75 Name of top_module
76
76
77 """
77 """
78 self.mark_module_reloadable(module_name)
78 self.mark_module_reloadable(module_name)
79
79
80 __import__(module_name)
80 __import__(module_name)
81 top_name = module_name.split('.')[0]
81 top_name = module_name.split('.')[0]
82 top_module = sys.modules[top_name]
82 top_module = sys.modules[top_name]
83 return top_module, top_name
83 return top_module, top_name
84
84
85 def check(self, check_all=False):
85 def check(self, check_all=False):
86 """Check whether some modules need to be reloaded."""
86 """Check whether some modules need to be reloaded."""
87
87
88 if not self.enabled and not check_all:
88 if not self.enabled and not check_all:
89 return
89 return
90
90
91 if check_all or self.check_all:
91 if check_all or self.check_all:
92 modules = sys.modules.keys()
92 modules = sys.modules.keys()
93 else:
93 else:
94 modules = self.modules.keys()
94 modules = self.modules.keys()
95
95
96 for modname in modules:
96 for modname in modules:
97 m = sys.modules.get(modname, None)
97 m = sys.modules.get(modname, None)
98
98
99 if modname in self.skip_modules:
99 if modname in self.skip_modules:
100 continue
100 continue
101
101
102 if not hasattr(m, '__file__'):
102 if not hasattr(m, '__file__'):
103 continue
103 continue
104
104
105 if m.__name__ == '__main__':
105 if m.__name__ == '__main__':
106 # we cannot reload(__main__)
106 # we cannot reload(__main__)
107 continue
107 continue
108
108
109 filename = m.__file__
109 filename = m.__file__
110 path, ext = os.path.splitext(filename)
110 path, ext = os.path.splitext(filename)
111
111
112 if ext.lower() == '.py':
112 if ext.lower() == '.py':
113 ext = PY_COMPILED_EXT
113 ext = PY_COMPILED_EXT
114 pyc_filename = path + PY_COMPILED_EXT
114 pyc_filename = path + PY_COMPILED_EXT
115 py_filename = filename
115 py_filename = filename
116 else:
116 else:
117 pyc_filename = filename
117 pyc_filename = filename
118 py_filename = filename[:-1]
118 py_filename = filename[:-1]
119
119
120 if ext != PY_COMPILED_EXT:
120 if ext != PY_COMPILED_EXT:
121 continue
121 continue
122
122
123 try:
123 try:
124 pymtime = os.stat(py_filename).st_mtime
124 pymtime = os.stat(py_filename).st_mtime
125 if pymtime <= os.stat(pyc_filename).st_mtime:
125 if pymtime <= os.stat(pyc_filename).st_mtime:
126 continue
126 continue
127 if self.failed.get(py_filename, None) == pymtime:
127 if self.failed.get(py_filename, None) == pymtime:
128 continue
128 continue
129 except OSError:
129 except OSError:
130 continue
130 continue
131
131
132 try:
132 try:
133 superreload(m, reload, self.old_objects)
133 superreload(m, reload, self.old_objects)
134 if py_filename in self.failed:
134 if py_filename in self.failed:
135 del self.failed[py_filename]
135 del self.failed[py_filename]
136 except:
136 except:
137 print >> sys.stderr, "[autoreload of %s failed: %s]" % (
137 print >> sys.stderr, "[autoreload of %s failed: %s]" % (
138 modname, traceback.format_exc(1))
138 modname, traceback.format_exc(1))
139 self.failed[py_filename] = pymtime
139 self.failed[py_filename] = pymtime
140
140
141 #------------------------------------------------------------------------------
141 #------------------------------------------------------------------------------
142 # superreload
142 # superreload
143 #------------------------------------------------------------------------------
143 #------------------------------------------------------------------------------
144
144
145 def update_function(old, new):
145 def update_function(old, new):
146 """Upgrade the code object of a function"""
146 """Upgrade the code object of a function"""
147 for name in ['func_code', 'func_defaults', 'func_doc',
147 for name in ['func_code', 'func_defaults', 'func_doc',
148 'func_closure', 'func_globals', 'func_dict']:
148 'func_closure', 'func_globals', 'func_dict']:
149 try:
149 try:
150 setattr(old, name, getattr(new, name))
150 setattr(old, name, getattr(new, name))
151 except (AttributeError, TypeError):
151 except (AttributeError, TypeError):
152 pass
152 pass
153
153
154 def update_class(old, new):
154 def update_class(old, new):
155 """Replace stuff in the __dict__ of a class, and upgrade
155 """Replace stuff in the __dict__ of a class, and upgrade
156 method code objects"""
156 method code objects"""
157 for key in old.__dict__.keys():
157 for key in old.__dict__.keys():
158 old_obj = getattr(old, key)
158 old_obj = getattr(old, key)
159
159
160 try:
160 try:
161 new_obj = getattr(new, key)
161 new_obj = getattr(new, key)
162 except AttributeError:
162 except AttributeError:
163 # obsolete attribute: remove it
163 # obsolete attribute: remove it
164 try:
164 try:
165 delattr(old, key)
165 delattr(old, key)
166 except (AttributeError, TypeError):
166 except (AttributeError, TypeError):
167 pass
167 pass
168 continue
168 continue
169
169
170 if update_generic(old_obj, new_obj): continue
170 if update_generic(old_obj, new_obj): continue
171
171
172 try:
172 try:
173 setattr(old, key, getattr(new, key))
173 setattr(old, key, getattr(new, key))
174 except (AttributeError, TypeError):
174 except (AttributeError, TypeError):
175 pass # skip non-writable attributes
175 pass # skip non-writable attributes
176
176
177 def update_property(old, new):
177 def update_property(old, new):
178 """Replace get/set/del functions of a property"""
178 """Replace get/set/del functions of a property"""
179 update_generic(old.fdel, new.fdel)
179 update_generic(old.fdel, new.fdel)
180 update_generic(old.fget, new.fget)
180 update_generic(old.fget, new.fget)
181 update_generic(old.fset, new.fset)
181 update_generic(old.fset, new.fset)
182
182
183 def isinstance2(a, b, typ):
183 def isinstance2(a, b, typ):
184 return isinstance(a, typ) and isinstance(b, typ)
184 return isinstance(a, typ) and isinstance(b, typ)
185
185
186 UPDATE_RULES = [
186 UPDATE_RULES = [
187 (lambda a, b: isinstance2(a, b, types.ClassType),
187 (lambda a, b: isinstance2(a, b, types.ClassType),
188 update_class),
188 update_class),
189 (lambda a, b: isinstance2(a, b, types.TypeType),
189 (lambda a, b: isinstance2(a, b, types.TypeType),
190 update_class),
190 update_class),
191 (lambda a, b: isinstance2(a, b, types.FunctionType),
191 (lambda a, b: isinstance2(a, b, types.FunctionType),
192 update_function),
192 update_function),
193 (lambda a, b: isinstance2(a, b, property),
193 (lambda a, b: isinstance2(a, b, property),
194 update_property),
194 update_property),
195 (lambda a, b: isinstance2(a, b, types.MethodType),
195 (lambda a, b: isinstance2(a, b, types.MethodType),
196 lambda a, b: update_function(a.im_func, b.im_func)),
196 lambda a, b: update_function(a.im_func, b.im_func)),
197 ]
197 ]
198
198
199 def update_generic(a, b):
199 def update_generic(a, b):
200 for type_check, update in UPDATE_RULES:
200 for type_check, update in UPDATE_RULES:
201 if type_check(a, b):
201 if type_check(a, b):
202 update(a, b)
202 update(a, b)
203 return True
203 return True
204 return False
204 return False
205
205
206 class StrongRef(object):
206 class StrongRef(object):
207 def __init__(self, obj):
207 def __init__(self, obj):
208 self.obj = obj
208 self.obj = obj
209 def __call__(self):
209 def __call__(self):
210 return self.obj
210 return self.obj
211
211
212 def superreload(module, reload=reload, old_objects={}):
212 def superreload(module, reload=reload, old_objects={}):
213 """Enhanced version of the builtin reload function.
213 """Enhanced version of the builtin reload function.
214
214
215 superreload remembers objects previously in the module, and
215 superreload remembers objects previously in the module, and
216
216
217 - upgrades the class dictionary of every old class in the module
217 - upgrades the class dictionary of every old class in the module
218 - upgrades the code object of every old function and method
218 - upgrades the code object of every old function and method
219 - clears the module's namespace before reloading
219 - clears the module's namespace before reloading
220
220
221 """
221 """
222
222
223 # collect old objects in the module
223 # collect old objects in the module
224 for name, obj in module.__dict__.items():
224 for name, obj in module.__dict__.items():
225 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
225 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
226 continue
226 continue
227 key = (module.__name__, name)
227 key = (module.__name__, name)
228 try:
228 try:
229 old_objects.setdefault(key, []).append(weakref.ref(obj))
229 old_objects.setdefault(key, []).append(weakref.ref(obj))
230 except TypeError:
230 except TypeError:
231 # weakref doesn't work for all types;
231 # weakref doesn't work for all types;
232 # create strong references for 'important' cases
232 # create strong references for 'important' cases
233 if isinstance(obj, types.ClassType):
233 if isinstance(obj, types.ClassType):
234 old_objects.setdefault(key, []).append(StrongRef(obj))
234 old_objects.setdefault(key, []).append(StrongRef(obj))
235
235
236 # reload module
236 # reload module
237 try:
237 try:
238 # clear namespace first from old cruft
238 # clear namespace first from old cruft
239 old_dict = module.__dict__.copy()
239 old_name = module.__name__
240 old_name = module.__name__
240 module.__dict__.clear()
241 module.__dict__.clear()
241 module.__dict__['__name__'] = old_name
242 module.__dict__['__name__'] = old_name
242 except (TypeError, AttributeError, KeyError):
243 except (TypeError, AttributeError, KeyError):
243 pass
244 pass
244 module = reload(module)
245
246 try:
247 module = reload(module)
248 except:
249 # restore module dictionary on failed reload
250 module.__dict__.update(old_dict)
251 raise
245
252
246 # iterate over all objects and update functions & classes
253 # iterate over all objects and update functions & classes
247 for name, new_obj in module.__dict__.items():
254 for name, new_obj in module.__dict__.items():
248 key = (module.__name__, name)
255 key = (module.__name__, name)
249 if key not in old_objects: continue
256 if key not in old_objects: continue
250
257
251 new_refs = []
258 new_refs = []
252 for old_ref in old_objects[key]:
259 for old_ref in old_objects[key]:
253 old_obj = old_ref()
260 old_obj = old_ref()
254 if old_obj is None: continue
261 if old_obj is None: continue
255 new_refs.append(old_ref)
262 new_refs.append(old_ref)
256 update_generic(old_obj, new_obj)
263 update_generic(old_obj, new_obj)
257
264
258 if new_refs:
265 if new_refs:
259 old_objects[key] = new_refs
266 old_objects[key] = new_refs
260 else:
267 else:
261 del old_objects[key]
268 del old_objects[key]
262
269
263 return module
270 return module
264
271
265 #------------------------------------------------------------------------------
272 #------------------------------------------------------------------------------
266 # IPython connectivity
273 # IPython connectivity
267 #------------------------------------------------------------------------------
274 #------------------------------------------------------------------------------
268
275
269 from IPython.core.plugin import Plugin
276 from IPython.core.plugin import Plugin
270 from IPython.core.hooks import TryNext
277 from IPython.core.hooks import TryNext
271
278
272 class AutoreloadInterface(object):
279 class AutoreloadInterface(object):
273 def __init__(self, *a, **kw):
280 def __init__(self, *a, **kw):
274 super(AutoreloadInterface, self).__init__(*a, **kw)
281 super(AutoreloadInterface, self).__init__(*a, **kw)
275 self._reloader = ModuleReloader()
282 self._reloader = ModuleReloader()
276 self._reloader.check_all = False
283 self._reloader.check_all = False
277
284
278 def magic_autoreload(self, ipself, parameter_s=''):
285 def magic_autoreload(self, ipself, parameter_s=''):
279 r"""%autoreload => Reload modules automatically
286 r"""%autoreload => Reload modules automatically
280
287
281 %autoreload
288 %autoreload
282 Reload all modules (except those excluded by %aimport) automatically
289 Reload all modules (except those excluded by %aimport) automatically
283 now.
290 now.
284
291
285 %autoreload 0
292 %autoreload 0
286 Disable automatic reloading.
293 Disable automatic reloading.
287
294
288 %autoreload 1
295 %autoreload 1
289 Reload all modules imported with %aimport every time before executing
296 Reload all modules imported with %aimport every time before executing
290 the Python code typed.
297 the Python code typed.
291
298
292 %autoreload 2
299 %autoreload 2
293 Reload all modules (except those excluded by %aimport) every time
300 Reload all modules (except those excluded by %aimport) every time
294 before executing the Python code typed.
301 before executing the Python code typed.
295
302
296 Reloading Python modules in a reliable way is in general
303 Reloading Python modules in a reliable way is in general
297 difficult, and unexpected things may occur. %autoreload tries to
304 difficult, and unexpected things may occur. %autoreload tries to
298 work around common pitfalls by replacing function code objects and
305 work around common pitfalls by replacing function code objects and
299 parts of classes previously in the module with new versions. This
306 parts of classes previously in the module with new versions. This
300 makes the following things to work:
307 makes the following things to work:
301
308
302 - Functions and classes imported via 'from xxx import foo' are upgraded
309 - Functions and classes imported via 'from xxx import foo' are upgraded
303 to new versions when 'xxx' is reloaded.
310 to new versions when 'xxx' is reloaded.
304
311
305 - Methods and properties of classes are upgraded on reload, so that
312 - Methods and properties of classes are upgraded on reload, so that
306 calling 'c.foo()' on an object 'c' created before the reload causes
313 calling 'c.foo()' on an object 'c' created before the reload causes
307 the new code for 'foo' to be executed.
314 the new code for 'foo' to be executed.
308
315
309 Some of the known remaining caveats are:
316 Some of the known remaining caveats are:
310
317
311 - Replacing code objects does not always succeed: changing a @property
318 - Replacing code objects does not always succeed: changing a @property
312 in a class to an ordinary method or a method to a member variable
319 in a class to an ordinary method or a method to a member variable
313 can cause problems (but in old objects only).
320 can cause problems (but in old objects only).
314
321
315 - Functions that are removed (eg. via monkey-patching) from a module
322 - Functions that are removed (eg. via monkey-patching) from a module
316 before it is reloaded are not upgraded.
323 before it is reloaded are not upgraded.
317
324
318 - C extension modules cannot be reloaded, and so cannot be
325 - C extension modules cannot be reloaded, and so cannot be
319 autoreloaded.
326 autoreloaded.
320
327
321 """
328 """
322 if parameter_s == '':
329 if parameter_s == '':
323 self._reloader.check(True)
330 self._reloader.check(True)
324 elif parameter_s == '0':
331 elif parameter_s == '0':
325 self._reloader.enabled = False
332 self._reloader.enabled = False
326 elif parameter_s == '1':
333 elif parameter_s == '1':
327 self._reloader.check_all = False
334 self._reloader.check_all = False
328 self._reloader.enabled = True
335 self._reloader.enabled = True
329 elif parameter_s == '2':
336 elif parameter_s == '2':
330 self._reloader.check_all = True
337 self._reloader.check_all = True
331 self._reloader.enabled = True
338 self._reloader.enabled = True
332
339
333 def magic_aimport(self, ipself, parameter_s='', stream=None):
340 def magic_aimport(self, ipself, parameter_s='', stream=None):
334 """%aimport => Import modules for automatic reloading.
341 """%aimport => Import modules for automatic reloading.
335
342
336 %aimport
343 %aimport
337 List modules to automatically import and not to import.
344 List modules to automatically import and not to import.
338
345
339 %aimport foo
346 %aimport foo
340 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
347 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
341
348
342 %aimport -foo
349 %aimport -foo
343 Mark module 'foo' to not be autoreloaded for %autoreload 1
350 Mark module 'foo' to not be autoreloaded for %autoreload 1
344
351
345 """
352 """
346
353
347 modname = parameter_s
354 modname = parameter_s
348 if not modname:
355 if not modname:
349 to_reload = self._reloader.modules.keys()
356 to_reload = self._reloader.modules.keys()
350 to_reload.sort()
357 to_reload.sort()
351 to_skip = self._reloader.skip_modules.keys()
358 to_skip = self._reloader.skip_modules.keys()
352 to_skip.sort()
359 to_skip.sort()
353 if stream is None:
360 if stream is None:
354 stream = sys.stdout
361 stream = sys.stdout
355 if self._reloader.check_all:
362 if self._reloader.check_all:
356 stream.write("Modules to reload:\nall-except-skipped\n")
363 stream.write("Modules to reload:\nall-except-skipped\n")
357 else:
364 else:
358 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
365 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
359 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
366 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
360 elif modname.startswith('-'):
367 elif modname.startswith('-'):
361 modname = modname[1:]
368 modname = modname[1:]
362 self._reloader.mark_module_skipped(modname)
369 self._reloader.mark_module_skipped(modname)
363 else:
370 else:
364 top_module, top_name = self._reloader.aimport_module(modname)
371 top_module, top_name = self._reloader.aimport_module(modname)
365
372
366 # Inject module to user namespace
373 # Inject module to user namespace
367 ipself.push({top_name: top_module})
374 ipself.push({top_name: top_module})
368
375
369 def pre_run_code_hook(self, ipself):
376 def pre_run_code_hook(self, ipself):
370 if not self._reloader.enabled:
377 if not self._reloader.enabled:
371 raise TryNext
378 raise TryNext
372 try:
379 try:
373 self._reloader.check()
380 self._reloader.check()
374 except:
381 except:
375 pass
382 pass
376
383
377 class AutoreloadPlugin(AutoreloadInterface, Plugin):
384 class AutoreloadPlugin(AutoreloadInterface, Plugin):
378 def __init__(self, shell=None, config=None):
385 def __init__(self, shell=None, config=None):
379 super(AutoreloadPlugin, self).__init__(shell=shell, config=config)
386 super(AutoreloadPlugin, self).__init__(shell=shell, config=config)
380
387
381 self.shell.define_magic('autoreload', self.magic_autoreload)
388 self.shell.define_magic('autoreload', self.magic_autoreload)
382 self.shell.define_magic('aimport', self.magic_aimport)
389 self.shell.define_magic('aimport', self.magic_aimport)
383 self.shell.set_hook('pre_run_code_hook', self.pre_run_code_hook)
390 self.shell.set_hook('pre_run_code_hook', self.pre_run_code_hook)
384
391
385 _loaded = False
392 _loaded = False
386
393
387 def load_ipython_extension(ip):
394 def load_ipython_extension(ip):
388 """Load the extension in IPython."""
395 """Load the extension in IPython."""
389 global _loaded
396 global _loaded
390 if not _loaded:
397 if not _loaded:
391 plugin = AutoreloadPlugin(shell=ip, config=ip.config)
398 plugin = AutoreloadPlugin(shell=ip, config=ip.config)
392 ip.plugin_manager.register_plugin('autoreload', plugin)
399 ip.plugin_manager.register_plugin('autoreload', plugin)
393 _loaded = True
400 _loaded = True
General Comments 0
You need to be logged in to leave comments. Login now