##// END OF EJS Templates
ipy_autoreload: fix %aimporting submodules
Pauli Virtanen -
Show More
@@ -1,340 +1,343 b''
1 1 """
2 2 IPython extension: autoreload modules before executing the next line
3 3
4 4 Try::
5 5
6 6 %autoreload?
7 7
8 8 for documentation.
9 9 """
10 10
11 11 # Pauli Virtanen <pav@iki.fi>, 2008.
12 12 # Thomas Heller, 2000.
13 13 #
14 14 # This IPython module is written by Pauli Virtanen, based on the autoreload
15 15 # code by Thomas Heller.
16 16
17 17 #------------------------------------------------------------------------------
18 18 # Autoreload functionality
19 19 #------------------------------------------------------------------------------
20 20
21 21 import time, os, threading, sys, types, imp, inspect, traceback, atexit
22 22 import weakref
23 23
24 24 def _get_compiled_ext():
25 25 """Official way to get the extension of compiled files (.pyc or .pyo)"""
26 26 for ext, mode, typ in imp.get_suffixes():
27 27 if typ == imp.PY_COMPILED:
28 28 return ext
29 29
30 30 PY_COMPILED_EXT = _get_compiled_ext()
31 31
32 32 class ModuleReloader(object):
33 33 failed = {}
34 34 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
35 35
36 36 modules = {}
37 37 """Modules specially marked as autoreloadable."""
38 38
39 39 skip_modules = {}
40 40 """Modules specially marked as not autoreloadable."""
41 41
42 42 check_all = True
43 43 """Autoreload all modules, not just those listed in 'modules'"""
44 44
45 45 old_objects = {}
46 46 """(module-name, name) -> weakref, for replacing old code objects"""
47 47
48 48 def check(self, check_all=False):
49 49 """Check whether some modules need to be reloaded."""
50 50
51 51 if check_all or self.check_all:
52 52 modules = sys.modules.keys()
53 53 else:
54 54 modules = self.modules.keys()
55 55
56 56 for modname in modules:
57 57 m = sys.modules.get(modname, None)
58 58
59 59 if modname in self.skip_modules:
60 60 continue
61 61
62 62 if not hasattr(m, '__file__'):
63 63 continue
64 64
65 65 if m.__name__ == '__main__':
66 66 # we cannot reload(__main__)
67 67 continue
68 68
69 69 filename = m.__file__
70 70 dirname = os.path.dirname(filename)
71 71 path, ext = os.path.splitext(filename)
72 72
73 73 if ext.lower() == '.py':
74 74 ext = PY_COMPILED_EXT
75 75 filename = os.path.join(dirname, path + PY_COMPILED_EXT)
76 76
77 77 if ext != PY_COMPILED_EXT:
78 78 continue
79 79
80 80 try:
81 81 pymtime = os.stat(filename[:-1]).st_mtime
82 82 if pymtime <= os.stat(filename).st_mtime:
83 83 continue
84 84 if self.failed.get(filename[:-1], None) == pymtime:
85 85 continue
86 86 except OSError:
87 87 continue
88 88
89 89 try:
90 90 superreload(m, reload, self.old_objects)
91 91 if filename[:-1] in self.failed:
92 92 del self.failed[filename[:-1]]
93 93 except:
94 94 print >> sys.stderr, "[autoreload of %s failed: %s]" % (
95 95 modname, traceback.format_exc(1))
96 96 self.failed[filename[:-1]] = pymtime
97 97
98 98 #------------------------------------------------------------------------------
99 99 # superreload
100 100 #------------------------------------------------------------------------------
101 101
102 102 def update_function(old, new):
103 103 """Upgrade the code object of a function"""
104 104 for name in ['func_code', 'func_defaults', 'func_doc',
105 105 'func_closure', 'func_globals', 'func_dict']:
106 106 try:
107 107 setattr(old, name, getattr(new, name))
108 108 except (AttributeError, TypeError):
109 109 pass
110 110
111 111 def update_class(old, new):
112 112 """Replace stuff in the __dict__ of a class, and upgrade
113 113 method code objects"""
114 114 for key in old.__dict__.keys():
115 115 old_obj = getattr(old, key)
116 116
117 117 try:
118 118 new_obj = getattr(new, key)
119 119 except AttributeError:
120 120 # obsolete attribute: remove it
121 121 try:
122 122 delattr(old, key)
123 123 except (AttributeError, TypeError):
124 124 pass
125 125 continue
126 126
127 127 if update_generic(old_obj, new_obj): continue
128 128
129 129 try:
130 130 setattr(old, key, getattr(new, key))
131 131 except (AttributeError, TypeError):
132 132 pass # skip non-writable attributes
133 133
134 134 def update_property(old, new):
135 135 """Replace get/set/del functions of a property"""
136 136 update_generic(old.fdel, new.fdel)
137 137 update_generic(old.fget, new.fget)
138 138 update_generic(old.fset, new.fset)
139 139
140 140 def isinstance2(a, b, typ):
141 141 return isinstance(a, typ) and isinstance(b, typ)
142 142
143 143 UPDATE_RULES = [
144 144 (lambda a, b: isinstance2(a, b, types.ClassType),
145 145 update_class),
146 146 (lambda a, b: isinstance2(a, b, types.TypeType),
147 147 update_class),
148 148 (lambda a, b: isinstance2(a, b, types.FunctionType),
149 149 update_function),
150 150 (lambda a, b: isinstance2(a, b, property),
151 151 update_property),
152 152 (lambda a, b: isinstance2(a, b, types.MethodType),
153 153 lambda a, b: update_function(a.im_func, b.im_func)),
154 154 ]
155 155
156 156 def update_generic(a, b):
157 157 for type_check, update in UPDATE_RULES:
158 158 if type_check(a, b):
159 159 update(a, b)
160 160 return True
161 161 return False
162 162
163 163 class StrongRef(object):
164 164 def __init__(self, obj):
165 165 self.obj = obj
166 166 def __call__(self):
167 167 return self.obj
168 168
169 169 def superreload(module, reload=reload, old_objects={}):
170 170 """Enhanced version of the builtin reload function.
171 171
172 172 superreload remembers objects previously in the module, and
173 173
174 174 - upgrades the class dictionary of every old class in the module
175 175 - upgrades the code object of every old function and method
176 176 - clears the module's namespace before reloading
177 177
178 178 """
179 179
180 180 # collect old objects in the module
181 181 for name, obj in module.__dict__.items():
182 182 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
183 183 continue
184 184 key = (module.__name__, name)
185 185 try:
186 186 old_objects.setdefault(key, []).append(weakref.ref(obj))
187 187 except TypeError:
188 188 # weakref doesn't work for all types;
189 189 # create strong references for 'important' cases
190 190 if isinstance(obj, types.ClassType):
191 191 old_objects.setdefault(key, []).append(StrongRef(obj))
192 192
193 193 # reload module
194 194 try:
195 195 # clear namespace first from old cruft
196 196 old_name = module.__name__
197 197 module.__dict__.clear()
198 198 module.__dict__['__name__'] = old_name
199 199 except (TypeError, AttributeError, KeyError):
200 200 pass
201 201 module = reload(module)
202 202
203 203 # iterate over all objects and update functions & classes
204 204 for name, new_obj in module.__dict__.items():
205 205 key = (module.__name__, name)
206 206 if key not in old_objects: continue
207 207
208 208 new_refs = []
209 209 for old_ref in old_objects[key]:
210 210 old_obj = old_ref()
211 211 if old_obj is None: continue
212 212 new_refs.append(old_ref)
213 213 update_generic(old_obj, new_obj)
214 214
215 215 if new_refs:
216 216 old_objects[key] = new_refs
217 217 else:
218 218 del old_objects[key]
219 219
220 220 return module
221 221
222 222 reloader = ModuleReloader()
223 223
224 224 #------------------------------------------------------------------------------
225 225 # IPython connectivity
226 226 #------------------------------------------------------------------------------
227 227 import IPython.ipapi
228 228
229 229 ip = IPython.ipapi.get()
230 230
231 231 autoreload_enabled = False
232 232
233 233 def runcode_hook(self):
234 234 if not autoreload_enabled:
235 235 raise IPython.ipapi.TryNext
236 236 try:
237 237 reloader.check()
238 238 except:
239 239 pass
240 240
241 241 def enable_autoreload():
242 242 global autoreload_enabled
243 243 autoreload_enabled = True
244 244
245 245 def disable_autoreload():
246 246 global autoreload_enabled
247 247 autoreload_enabled = False
248 248
249 249 def autoreload_f(self, parameter_s=''):
250 250 r""" %autoreload => Reload modules automatically
251 251
252 252 %autoreload
253 253 Reload all modules (except those excluded by %aimport) automatically now.
254 254
255 255 %autoreload 1
256 256 Reload all modules imported with %aimport every time before executing
257 257 the Python code typed.
258 258
259 259 %autoreload 2
260 260 Reload all modules (except those excluded by %aimport) every time
261 261 before executing the Python code typed.
262 262
263 263 Reloading Python modules in a reliable way is in general difficult,
264 264 and unexpected things may occur. %autoreload tries to work
265 265 around common pitfalls by replacing code objects of functions
266 266 previously in the module with new versions. This makes the following
267 267 things to work:
268 268
269 269 - Functions and classes imported via 'from xxx import foo' are upgraded
270 270 to new versions when 'xxx' is reloaded.
271 271 - Methods and properties of classes are upgraded on reload, so that
272 272 calling 'c.foo()' on an object 'c' created before the reload causes
273 273 the new code for 'foo' to be executed.
274 274
275 275 Some of the known remaining caveats are:
276 276
277 277 - Replacing code objects does not always succeed: changing a @property
278 278 in a class to an ordinary method or a method to a member variable
279 279 can cause problems (but in old objects only).
280 280 - Functions that are removed (eg. via monkey-patching) from a module
281 281 before it is reloaded are not upgraded.
282 282 - C extension modules cannot be reloaded, and so cannot be
283 283 autoreloaded.
284 284
285 285 """
286 286 if parameter_s == '':
287 287 reloader.check(True)
288 288 elif parameter_s == '0':
289 289 disable_autoreload()
290 290 elif parameter_s == '1':
291 291 reloader.check_all = False
292 292 enable_autoreload()
293 293 elif parameter_s == '2':
294 294 reloader.check_all = True
295 295 enable_autoreload()
296 296
297 297 def aimport_f(self, parameter_s=''):
298 298 """%aimport => Import modules for automatic reloading.
299 299
300 300 %aimport
301 301 List modules to automatically import and not to import.
302 302
303 303 %aimport foo
304 304 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
305 305
306 306 %aimport -foo
307 307 Mark module 'foo' to not be autoreloaded for %autoreload 1
308 308
309 309 """
310 310
311 311 modname = parameter_s
312 312 if not modname:
313 313 to_reload = reloader.modules.keys()
314 314 to_reload.sort()
315 315 to_skip = reloader.skip_modules.keys()
316 316 to_skip.sort()
317 317 if reloader.check_all:
318 318 print "Modules to reload:\nall-expect-skipped"
319 319 else:
320 320 print "Modules to reload:\n%s" % ' '.join(to_reload)
321 321 print "\nModules to skip:\n%s" % ' '.join(to_skip)
322 322 elif modname.startswith('-'):
323 323 modname = modname[1:]
324 324 try: del reloader.modules[modname]
325 325 except KeyError: pass
326 326 reloader.skip_modules[modname] = True
327 327 else:
328 328 try: del reloader.skip_modules[modname]
329 329 except KeyError: pass
330 330 reloader.modules[modname] = True
331
332 mod = __import__(modname)
333 ip.to_user_ns({modname: mod})
331
332 # Inject module to user namespace; handle also submodules properly
333 __import__(modname)
334 basename = modname.split('.')[0]
335 mod = sys.modules[basename]
336 ip.to_user_ns({basename: mod})
334 337
335 338 def init():
336 339 ip.expose_magic('autoreload', autoreload_f)
337 340 ip.expose_magic('aimport', aimport_f)
338 341 ip.set_hook('pre_runcode_hook', runcode_hook)
339 342
340 343 init()
General Comments 0
You need to be logged in to leave comments. Login now