##// END OF EJS Templates
ipy_autoreload: fix %aimporting submodules
Pauli Virtanen -
Show More
@@ -1,340 +1,343 b''
1 """
1 """
2 IPython extension: autoreload modules before executing the next line
2 IPython extension: autoreload modules before executing the next line
3
3
4 Try::
4 Try::
5
5
6 %autoreload?
6 %autoreload?
7
7
8 for documentation.
8 for documentation.
9 """
9 """
10
10
11 # Pauli Virtanen <pav@iki.fi>, 2008.
11 # Pauli Virtanen <pav@iki.fi>, 2008.
12 # Thomas Heller, 2000.
12 # Thomas Heller, 2000.
13 #
13 #
14 # This IPython module is written by Pauli Virtanen, based on the autoreload
14 # This IPython module is written by Pauli Virtanen, based on the autoreload
15 # code by Thomas Heller.
15 # code by Thomas Heller.
16
16
17 #------------------------------------------------------------------------------
17 #------------------------------------------------------------------------------
18 # Autoreload functionality
18 # Autoreload functionality
19 #------------------------------------------------------------------------------
19 #------------------------------------------------------------------------------
20
20
21 import time, os, threading, sys, types, imp, inspect, traceback, atexit
21 import time, os, threading, sys, types, imp, inspect, traceback, atexit
22 import weakref
22 import weakref
23
23
24 def _get_compiled_ext():
24 def _get_compiled_ext():
25 """Official way to get the extension of compiled files (.pyc or .pyo)"""
25 """Official way to get the extension of compiled files (.pyc or .pyo)"""
26 for ext, mode, typ in imp.get_suffixes():
26 for ext, mode, typ in imp.get_suffixes():
27 if typ == imp.PY_COMPILED:
27 if typ == imp.PY_COMPILED:
28 return ext
28 return ext
29
29
30 PY_COMPILED_EXT = _get_compiled_ext()
30 PY_COMPILED_EXT = _get_compiled_ext()
31
31
32 class ModuleReloader(object):
32 class ModuleReloader(object):
33 failed = {}
33 failed = {}
34 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
34 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
35
35
36 modules = {}
36 modules = {}
37 """Modules specially marked as autoreloadable."""
37 """Modules specially marked as autoreloadable."""
38
38
39 skip_modules = {}
39 skip_modules = {}
40 """Modules specially marked as not autoreloadable."""
40 """Modules specially marked as not autoreloadable."""
41
41
42 check_all = True
42 check_all = True
43 """Autoreload all modules, not just those listed in 'modules'"""
43 """Autoreload all modules, not just those listed in 'modules'"""
44
44
45 old_objects = {}
45 old_objects = {}
46 """(module-name, name) -> weakref, for replacing old code objects"""
46 """(module-name, name) -> weakref, for replacing old code objects"""
47
47
48 def check(self, check_all=False):
48 def check(self, check_all=False):
49 """Check whether some modules need to be reloaded."""
49 """Check whether some modules need to be reloaded."""
50
50
51 if check_all or self.check_all:
51 if check_all or self.check_all:
52 modules = sys.modules.keys()
52 modules = sys.modules.keys()
53 else:
53 else:
54 modules = self.modules.keys()
54 modules = self.modules.keys()
55
55
56 for modname in modules:
56 for modname in modules:
57 m = sys.modules.get(modname, None)
57 m = sys.modules.get(modname, None)
58
58
59 if modname in self.skip_modules:
59 if modname in self.skip_modules:
60 continue
60 continue
61
61
62 if not hasattr(m, '__file__'):
62 if not hasattr(m, '__file__'):
63 continue
63 continue
64
64
65 if m.__name__ == '__main__':
65 if m.__name__ == '__main__':
66 # we cannot reload(__main__)
66 # we cannot reload(__main__)
67 continue
67 continue
68
68
69 filename = m.__file__
69 filename = m.__file__
70 dirname = os.path.dirname(filename)
70 dirname = os.path.dirname(filename)
71 path, ext = os.path.splitext(filename)
71 path, ext = os.path.splitext(filename)
72
72
73 if ext.lower() == '.py':
73 if ext.lower() == '.py':
74 ext = PY_COMPILED_EXT
74 ext = PY_COMPILED_EXT
75 filename = os.path.join(dirname, path + PY_COMPILED_EXT)
75 filename = os.path.join(dirname, path + PY_COMPILED_EXT)
76
76
77 if ext != PY_COMPILED_EXT:
77 if ext != PY_COMPILED_EXT:
78 continue
78 continue
79
79
80 try:
80 try:
81 pymtime = os.stat(filename[:-1]).st_mtime
81 pymtime = os.stat(filename[:-1]).st_mtime
82 if pymtime <= os.stat(filename).st_mtime:
82 if pymtime <= os.stat(filename).st_mtime:
83 continue
83 continue
84 if self.failed.get(filename[:-1], None) == pymtime:
84 if self.failed.get(filename[:-1], None) == pymtime:
85 continue
85 continue
86 except OSError:
86 except OSError:
87 continue
87 continue
88
88
89 try:
89 try:
90 superreload(m, reload, self.old_objects)
90 superreload(m, reload, self.old_objects)
91 if filename[:-1] in self.failed:
91 if filename[:-1] in self.failed:
92 del self.failed[filename[:-1]]
92 del self.failed[filename[:-1]]
93 except:
93 except:
94 print >> sys.stderr, "[autoreload of %s failed: %s]" % (
94 print >> sys.stderr, "[autoreload of %s failed: %s]" % (
95 modname, traceback.format_exc(1))
95 modname, traceback.format_exc(1))
96 self.failed[filename[:-1]] = pymtime
96 self.failed[filename[:-1]] = pymtime
97
97
98 #------------------------------------------------------------------------------
98 #------------------------------------------------------------------------------
99 # superreload
99 # superreload
100 #------------------------------------------------------------------------------
100 #------------------------------------------------------------------------------
101
101
102 def update_function(old, new):
102 def update_function(old, new):
103 """Upgrade the code object of a function"""
103 """Upgrade the code object of a function"""
104 for name in ['func_code', 'func_defaults', 'func_doc',
104 for name in ['func_code', 'func_defaults', 'func_doc',
105 'func_closure', 'func_globals', 'func_dict']:
105 'func_closure', 'func_globals', 'func_dict']:
106 try:
106 try:
107 setattr(old, name, getattr(new, name))
107 setattr(old, name, getattr(new, name))
108 except (AttributeError, TypeError):
108 except (AttributeError, TypeError):
109 pass
109 pass
110
110
111 def update_class(old, new):
111 def update_class(old, new):
112 """Replace stuff in the __dict__ of a class, and upgrade
112 """Replace stuff in the __dict__ of a class, and upgrade
113 method code objects"""
113 method code objects"""
114 for key in old.__dict__.keys():
114 for key in old.__dict__.keys():
115 old_obj = getattr(old, key)
115 old_obj = getattr(old, key)
116
116
117 try:
117 try:
118 new_obj = getattr(new, key)
118 new_obj = getattr(new, key)
119 except AttributeError:
119 except AttributeError:
120 # obsolete attribute: remove it
120 # obsolete attribute: remove it
121 try:
121 try:
122 delattr(old, key)
122 delattr(old, key)
123 except (AttributeError, TypeError):
123 except (AttributeError, TypeError):
124 pass
124 pass
125 continue
125 continue
126
126
127 if update_generic(old_obj, new_obj): continue
127 if update_generic(old_obj, new_obj): continue
128
128
129 try:
129 try:
130 setattr(old, key, getattr(new, key))
130 setattr(old, key, getattr(new, key))
131 except (AttributeError, TypeError):
131 except (AttributeError, TypeError):
132 pass # skip non-writable attributes
132 pass # skip non-writable attributes
133
133
134 def update_property(old, new):
134 def update_property(old, new):
135 """Replace get/set/del functions of a property"""
135 """Replace get/set/del functions of a property"""
136 update_generic(old.fdel, new.fdel)
136 update_generic(old.fdel, new.fdel)
137 update_generic(old.fget, new.fget)
137 update_generic(old.fget, new.fget)
138 update_generic(old.fset, new.fset)
138 update_generic(old.fset, new.fset)
139
139
140 def isinstance2(a, b, typ):
140 def isinstance2(a, b, typ):
141 return isinstance(a, typ) and isinstance(b, typ)
141 return isinstance(a, typ) and isinstance(b, typ)
142
142
143 UPDATE_RULES = [
143 UPDATE_RULES = [
144 (lambda a, b: isinstance2(a, b, types.ClassType),
144 (lambda a, b: isinstance2(a, b, types.ClassType),
145 update_class),
145 update_class),
146 (lambda a, b: isinstance2(a, b, types.TypeType),
146 (lambda a, b: isinstance2(a, b, types.TypeType),
147 update_class),
147 update_class),
148 (lambda a, b: isinstance2(a, b, types.FunctionType),
148 (lambda a, b: isinstance2(a, b, types.FunctionType),
149 update_function),
149 update_function),
150 (lambda a, b: isinstance2(a, b, property),
150 (lambda a, b: isinstance2(a, b, property),
151 update_property),
151 update_property),
152 (lambda a, b: isinstance2(a, b, types.MethodType),
152 (lambda a, b: isinstance2(a, b, types.MethodType),
153 lambda a, b: update_function(a.im_func, b.im_func)),
153 lambda a, b: update_function(a.im_func, b.im_func)),
154 ]
154 ]
155
155
156 def update_generic(a, b):
156 def update_generic(a, b):
157 for type_check, update in UPDATE_RULES:
157 for type_check, update in UPDATE_RULES:
158 if type_check(a, b):
158 if type_check(a, b):
159 update(a, b)
159 update(a, b)
160 return True
160 return True
161 return False
161 return False
162
162
163 class StrongRef(object):
163 class StrongRef(object):
164 def __init__(self, obj):
164 def __init__(self, obj):
165 self.obj = obj
165 self.obj = obj
166 def __call__(self):
166 def __call__(self):
167 return self.obj
167 return self.obj
168
168
169 def superreload(module, reload=reload, old_objects={}):
169 def superreload(module, reload=reload, old_objects={}):
170 """Enhanced version of the builtin reload function.
170 """Enhanced version of the builtin reload function.
171
171
172 superreload remembers objects previously in the module, and
172 superreload remembers objects previously in the module, and
173
173
174 - upgrades the class dictionary of every old class in the module
174 - upgrades the class dictionary of every old class in the module
175 - upgrades the code object of every old function and method
175 - upgrades the code object of every old function and method
176 - clears the module's namespace before reloading
176 - clears the module's namespace before reloading
177
177
178 """
178 """
179
179
180 # collect old objects in the module
180 # collect old objects in the module
181 for name, obj in module.__dict__.items():
181 for name, obj in module.__dict__.items():
182 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
182 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
183 continue
183 continue
184 key = (module.__name__, name)
184 key = (module.__name__, name)
185 try:
185 try:
186 old_objects.setdefault(key, []).append(weakref.ref(obj))
186 old_objects.setdefault(key, []).append(weakref.ref(obj))
187 except TypeError:
187 except TypeError:
188 # weakref doesn't work for all types;
188 # weakref doesn't work for all types;
189 # create strong references for 'important' cases
189 # create strong references for 'important' cases
190 if isinstance(obj, types.ClassType):
190 if isinstance(obj, types.ClassType):
191 old_objects.setdefault(key, []).append(StrongRef(obj))
191 old_objects.setdefault(key, []).append(StrongRef(obj))
192
192
193 # reload module
193 # reload module
194 try:
194 try:
195 # clear namespace first from old cruft
195 # clear namespace first from old cruft
196 old_name = module.__name__
196 old_name = module.__name__
197 module.__dict__.clear()
197 module.__dict__.clear()
198 module.__dict__['__name__'] = old_name
198 module.__dict__['__name__'] = old_name
199 except (TypeError, AttributeError, KeyError):
199 except (TypeError, AttributeError, KeyError):
200 pass
200 pass
201 module = reload(module)
201 module = reload(module)
202
202
203 # iterate over all objects and update functions & classes
203 # iterate over all objects and update functions & classes
204 for name, new_obj in module.__dict__.items():
204 for name, new_obj in module.__dict__.items():
205 key = (module.__name__, name)
205 key = (module.__name__, name)
206 if key not in old_objects: continue
206 if key not in old_objects: continue
207
207
208 new_refs = []
208 new_refs = []
209 for old_ref in old_objects[key]:
209 for old_ref in old_objects[key]:
210 old_obj = old_ref()
210 old_obj = old_ref()
211 if old_obj is None: continue
211 if old_obj is None: continue
212 new_refs.append(old_ref)
212 new_refs.append(old_ref)
213 update_generic(old_obj, new_obj)
213 update_generic(old_obj, new_obj)
214
214
215 if new_refs:
215 if new_refs:
216 old_objects[key] = new_refs
216 old_objects[key] = new_refs
217 else:
217 else:
218 del old_objects[key]
218 del old_objects[key]
219
219
220 return module
220 return module
221
221
222 reloader = ModuleReloader()
222 reloader = ModuleReloader()
223
223
224 #------------------------------------------------------------------------------
224 #------------------------------------------------------------------------------
225 # IPython connectivity
225 # IPython connectivity
226 #------------------------------------------------------------------------------
226 #------------------------------------------------------------------------------
227 import IPython.ipapi
227 import IPython.ipapi
228
228
229 ip = IPython.ipapi.get()
229 ip = IPython.ipapi.get()
230
230
231 autoreload_enabled = False
231 autoreload_enabled = False
232
232
233 def runcode_hook(self):
233 def runcode_hook(self):
234 if not autoreload_enabled:
234 if not autoreload_enabled:
235 raise IPython.ipapi.TryNext
235 raise IPython.ipapi.TryNext
236 try:
236 try:
237 reloader.check()
237 reloader.check()
238 except:
238 except:
239 pass
239 pass
240
240
241 def enable_autoreload():
241 def enable_autoreload():
242 global autoreload_enabled
242 global autoreload_enabled
243 autoreload_enabled = True
243 autoreload_enabled = True
244
244
245 def disable_autoreload():
245 def disable_autoreload():
246 global autoreload_enabled
246 global autoreload_enabled
247 autoreload_enabled = False
247 autoreload_enabled = False
248
248
249 def autoreload_f(self, parameter_s=''):
249 def autoreload_f(self, parameter_s=''):
250 r""" %autoreload => Reload modules automatically
250 r""" %autoreload => Reload modules automatically
251
251
252 %autoreload
252 %autoreload
253 Reload all modules (except those excluded by %aimport) automatically now.
253 Reload all modules (except those excluded by %aimport) automatically now.
254
254
255 %autoreload 1
255 %autoreload 1
256 Reload all modules imported with %aimport every time before executing
256 Reload all modules imported with %aimport every time before executing
257 the Python code typed.
257 the Python code typed.
258
258
259 %autoreload 2
259 %autoreload 2
260 Reload all modules (except those excluded by %aimport) every time
260 Reload all modules (except those excluded by %aimport) every time
261 before executing the Python code typed.
261 before executing the Python code typed.
262
262
263 Reloading Python modules in a reliable way is in general difficult,
263 Reloading Python modules in a reliable way is in general difficult,
264 and unexpected things may occur. %autoreload tries to work
264 and unexpected things may occur. %autoreload tries to work
265 around common pitfalls by replacing code objects of functions
265 around common pitfalls by replacing code objects of functions
266 previously in the module with new versions. This makes the following
266 previously in the module with new versions. This makes the following
267 things to work:
267 things to work:
268
268
269 - Functions and classes imported via 'from xxx import foo' are upgraded
269 - Functions and classes imported via 'from xxx import foo' are upgraded
270 to new versions when 'xxx' is reloaded.
270 to new versions when 'xxx' is reloaded.
271 - Methods and properties of classes are upgraded on reload, so that
271 - Methods and properties of classes are upgraded on reload, so that
272 calling 'c.foo()' on an object 'c' created before the reload causes
272 calling 'c.foo()' on an object 'c' created before the reload causes
273 the new code for 'foo' to be executed.
273 the new code for 'foo' to be executed.
274
274
275 Some of the known remaining caveats are:
275 Some of the known remaining caveats are:
276
276
277 - Replacing code objects does not always succeed: changing a @property
277 - Replacing code objects does not always succeed: changing a @property
278 in a class to an ordinary method or a method to a member variable
278 in a class to an ordinary method or a method to a member variable
279 can cause problems (but in old objects only).
279 can cause problems (but in old objects only).
280 - Functions that are removed (eg. via monkey-patching) from a module
280 - Functions that are removed (eg. via monkey-patching) from a module
281 before it is reloaded are not upgraded.
281 before it is reloaded are not upgraded.
282 - C extension modules cannot be reloaded, and so cannot be
282 - C extension modules cannot be reloaded, and so cannot be
283 autoreloaded.
283 autoreloaded.
284
284
285 """
285 """
286 if parameter_s == '':
286 if parameter_s == '':
287 reloader.check(True)
287 reloader.check(True)
288 elif parameter_s == '0':
288 elif parameter_s == '0':
289 disable_autoreload()
289 disable_autoreload()
290 elif parameter_s == '1':
290 elif parameter_s == '1':
291 reloader.check_all = False
291 reloader.check_all = False
292 enable_autoreload()
292 enable_autoreload()
293 elif parameter_s == '2':
293 elif parameter_s == '2':
294 reloader.check_all = True
294 reloader.check_all = True
295 enable_autoreload()
295 enable_autoreload()
296
296
297 def aimport_f(self, parameter_s=''):
297 def aimport_f(self, parameter_s=''):
298 """%aimport => Import modules for automatic reloading.
298 """%aimport => Import modules for automatic reloading.
299
299
300 %aimport
300 %aimport
301 List modules to automatically import and not to import.
301 List modules to automatically import and not to import.
302
302
303 %aimport foo
303 %aimport foo
304 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
304 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
305
305
306 %aimport -foo
306 %aimport -foo
307 Mark module 'foo' to not be autoreloaded for %autoreload 1
307 Mark module 'foo' to not be autoreloaded for %autoreload 1
308
308
309 """
309 """
310
310
311 modname = parameter_s
311 modname = parameter_s
312 if not modname:
312 if not modname:
313 to_reload = reloader.modules.keys()
313 to_reload = reloader.modules.keys()
314 to_reload.sort()
314 to_reload.sort()
315 to_skip = reloader.skip_modules.keys()
315 to_skip = reloader.skip_modules.keys()
316 to_skip.sort()
316 to_skip.sort()
317 if reloader.check_all:
317 if reloader.check_all:
318 print "Modules to reload:\nall-expect-skipped"
318 print "Modules to reload:\nall-expect-skipped"
319 else:
319 else:
320 print "Modules to reload:\n%s" % ' '.join(to_reload)
320 print "Modules to reload:\n%s" % ' '.join(to_reload)
321 print "\nModules to skip:\n%s" % ' '.join(to_skip)
321 print "\nModules to skip:\n%s" % ' '.join(to_skip)
322 elif modname.startswith('-'):
322 elif modname.startswith('-'):
323 modname = modname[1:]
323 modname = modname[1:]
324 try: del reloader.modules[modname]
324 try: del reloader.modules[modname]
325 except KeyError: pass
325 except KeyError: pass
326 reloader.skip_modules[modname] = True
326 reloader.skip_modules[modname] = True
327 else:
327 else:
328 try: del reloader.skip_modules[modname]
328 try: del reloader.skip_modules[modname]
329 except KeyError: pass
329 except KeyError: pass
330 reloader.modules[modname] = True
330 reloader.modules[modname] = True
331
331
332 mod = __import__(modname)
332 # Inject module to user namespace; handle also submodules properly
333 ip.to_user_ns({modname: mod})
333 __import__(modname)
334 basename = modname.split('.')[0]
335 mod = sys.modules[basename]
336 ip.to_user_ns({basename: mod})
334
337
335 def init():
338 def init():
336 ip.expose_magic('autoreload', autoreload_f)
339 ip.expose_magic('autoreload', autoreload_f)
337 ip.expose_magic('aimport', aimport_f)
340 ip.expose_magic('aimport', aimport_f)
338 ip.set_hook('pre_runcode_hook', runcode_hook)
341 ip.set_hook('pre_runcode_hook', runcode_hook)
339
342
340 init()
343 init()
General Comments 0
You need to be logged in to leave comments. Login now