##// END OF EJS Templates
Merge pull request #11644 from daharn/master...
Matthias Bussonnier -
r25092:bd5753b3 merge
parent child Browse files
Show More
@@ -1,534 +1,590 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 import inspect
118 from importlib import import_module
119 from importlib import import_module
119 from importlib.util import source_from_cache
120 from importlib.util import source_from_cache
120 from imp import reload
121 from imp import reload
121
122
122 #------------------------------------------------------------------------------
123 #------------------------------------------------------------------------------
123 # Autoreload functionality
124 # Autoreload functionality
124 #------------------------------------------------------------------------------
125 #------------------------------------------------------------------------------
125
126
126 class ModuleReloader(object):
127 class ModuleReloader(object):
127 enabled = False
128 enabled = False
128 """Whether this reloader is enabled"""
129 """Whether this reloader is enabled"""
129
130
130 check_all = True
131 check_all = True
131 """Autoreload all modules, not just those listed in 'modules'"""
132 """Autoreload all modules, not just those listed in 'modules'"""
132
133
133 def __init__(self):
134 def __init__(self):
134 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 self.failed = {}
136 self.failed = {}
136 # Modules specially marked as autoreloadable.
137 # Modules specially marked as autoreloadable.
137 self.modules = {}
138 self.modules = {}
138 # Modules specially marked as not autoreloadable.
139 # Modules specially marked as not autoreloadable.
139 self.skip_modules = {}
140 self.skip_modules = {}
140 # (module-name, name) -> weakref, for replacing old code objects
141 # (module-name, name) -> weakref, for replacing old code objects
141 self.old_objects = {}
142 self.old_objects = {}
142 # Module modification timestamps
143 # Module modification timestamps
143 self.modules_mtimes = {}
144 self.modules_mtimes = {}
144
145
145 # Cache module modification times
146 # Cache module modification times
146 self.check(check_all=True, do_reload=False)
147 self.check(check_all=True, do_reload=False)
147
148
148 def mark_module_skipped(self, module_name):
149 def mark_module_skipped(self, module_name):
149 """Skip reloading the named module in the future"""
150 """Skip reloading the named module in the future"""
150 try:
151 try:
151 del self.modules[module_name]
152 del self.modules[module_name]
152 except KeyError:
153 except KeyError:
153 pass
154 pass
154 self.skip_modules[module_name] = True
155 self.skip_modules[module_name] = True
155
156
156 def mark_module_reloadable(self, module_name):
157 def mark_module_reloadable(self, module_name):
157 """Reload the named module in the future (if it is imported)"""
158 """Reload the named module in the future (if it is imported)"""
158 try:
159 try:
159 del self.skip_modules[module_name]
160 del self.skip_modules[module_name]
160 except KeyError:
161 except KeyError:
161 pass
162 pass
162 self.modules[module_name] = True
163 self.modules[module_name] = True
163
164
164 def aimport_module(self, module_name):
165 def aimport_module(self, module_name):
165 """Import a module, and mark it reloadable
166 """Import a module, and mark it reloadable
166
167
167 Returns
168 Returns
168 -------
169 -------
169 top_module : module
170 top_module : module
170 The imported module if it is top-level, or the top-level
171 The imported module if it is top-level, or the top-level
171 top_name : module
172 top_name : module
172 Name of top_module
173 Name of top_module
173
174
174 """
175 """
175 self.mark_module_reloadable(module_name)
176 self.mark_module_reloadable(module_name)
176
177
177 import_module(module_name)
178 import_module(module_name)
178 top_name = module_name.split('.')[0]
179 top_name = module_name.split('.')[0]
179 top_module = sys.modules[top_name]
180 top_module = sys.modules[top_name]
180 return top_module, top_name
181 return top_module, top_name
181
182
182 def filename_and_mtime(self, module):
183 def filename_and_mtime(self, module):
183 if not hasattr(module, '__file__') or module.__file__ is None:
184 if not hasattr(module, '__file__') or module.__file__ is None:
184 return None, None
185 return None, None
185
186
186 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
187 # we cannot reload(__main__) or reload(__mp_main__)
188 # we cannot reload(__main__) or reload(__mp_main__)
188 return None, None
189 return None, None
189
190
190 filename = module.__file__
191 filename = module.__file__
191 path, ext = os.path.splitext(filename)
192 path, ext = os.path.splitext(filename)
192
193
193 if ext.lower() == '.py':
194 if ext.lower() == '.py':
194 py_filename = filename
195 py_filename = filename
195 else:
196 else:
196 try:
197 try:
197 py_filename = source_from_cache(filename)
198 py_filename = source_from_cache(filename)
198 except ValueError:
199 except ValueError:
199 return None, None
200 return None, None
200
201
201 try:
202 try:
202 pymtime = os.stat(py_filename).st_mtime
203 pymtime = os.stat(py_filename).st_mtime
203 except OSError:
204 except OSError:
204 return None, None
205 return None, None
205
206
206 return py_filename, pymtime
207 return py_filename, pymtime
207
208
208 def check(self, check_all=False, do_reload=True):
209 def check(self, check_all=False, do_reload=True):
209 """Check whether some modules need to be reloaded."""
210 """Check whether some modules need to be reloaded."""
210
211
211 if not self.enabled and not check_all:
212 if not self.enabled and not check_all:
212 return
213 return
213
214
214 if check_all or self.check_all:
215 if check_all or self.check_all:
215 modules = list(sys.modules.keys())
216 modules = list(sys.modules.keys())
216 else:
217 else:
217 modules = list(self.modules.keys())
218 modules = list(self.modules.keys())
218
219
219 for modname in modules:
220 for modname in modules:
220 m = sys.modules.get(modname, None)
221 m = sys.modules.get(modname, None)
221
222
222 if modname in self.skip_modules:
223 if modname in self.skip_modules:
223 continue
224 continue
224
225
225 py_filename, pymtime = self.filename_and_mtime(m)
226 py_filename, pymtime = self.filename_and_mtime(m)
226 if py_filename is None:
227 if py_filename is None:
227 continue
228 continue
228
229
229 try:
230 try:
230 if pymtime <= self.modules_mtimes[modname]:
231 if pymtime <= self.modules_mtimes[modname]:
231 continue
232 continue
232 except KeyError:
233 except KeyError:
233 self.modules_mtimes[modname] = pymtime
234 self.modules_mtimes[modname] = pymtime
234 continue
235 continue
235 else:
236 else:
236 if self.failed.get(py_filename, None) == pymtime:
237 if self.failed.get(py_filename, None) == pymtime:
237 continue
238 continue
238
239
239 self.modules_mtimes[modname] = pymtime
240 self.modules_mtimes[modname] = pymtime
240
241
241 # If we've reached this point, we should try to reload the module
242 # If we've reached this point, we should try to reload the module
242 if do_reload:
243 if do_reload:
243 try:
244 try:
244 superreload(m, reload, self.old_objects)
245 superreload(m, reload, self.old_objects)
245 if py_filename in self.failed:
246 if py_filename in self.failed:
246 del self.failed[py_filename]
247 del self.failed[py_filename]
247 except:
248 except:
248 print("[autoreload of %s failed: %s]" % (
249 print("[autoreload of %s failed: %s]" % (
249 modname, traceback.format_exc(10)), file=sys.stderr)
250 modname, traceback.format_exc(10)), file=sys.stderr)
250 self.failed[py_filename] = pymtime
251 self.failed[py_filename] = pymtime
251
252
252 #------------------------------------------------------------------------------
253 #------------------------------------------------------------------------------
253 # superreload
254 # superreload
254 #------------------------------------------------------------------------------
255 #------------------------------------------------------------------------------
255
256
256
257
257 func_attrs = ['__code__', '__defaults__', '__doc__',
258 func_attrs = ['__code__', '__defaults__', '__doc__',
258 '__closure__', '__globals__', '__dict__']
259 '__closure__', '__globals__', '__dict__']
259
260
260
261
261 def update_function(old, new):
262 def update_function(old, new):
262 """Upgrade the code object of a function"""
263 """Upgrade the code object of a function"""
263 for name in func_attrs:
264 for name in func_attrs:
264 try:
265 try:
265 setattr(old, name, getattr(new, name))
266 setattr(old, name, getattr(new, name))
266 except (AttributeError, TypeError):
267 except (AttributeError, TypeError):
267 pass
268 pass
268
269
269
270
271 def update_instances(old, new, objects=None, visited={}):
272 """Iterate through objects recursively, searching for instances of old and
273 replace their __class__ reference with new. If no objects are given, start
274 with the current ipython workspace.
275 """
276 if objects is None:
277 # make sure visited is cleaned when not called recursively
278 visited = {}
279 # find ipython workspace stack frame
280 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
281 if 'trigger' in frame_nfo.function)
282 # build generator for non-private variable values from workspace
283 shell = frame.f_locals['self'].shell
284 user_ns = shell.user_ns
285 user_ns_hidden = shell.user_ns_hidden
286 nonmatching = object()
287 objects = ( value for key, value in user_ns.items()
288 if not key.startswith('_')
289 and (value is not user_ns_hidden.get(key, nonmatching))
290 and not inspect.ismodule(value))
291
292 # use dict values if objects is a dict but don't touch private variables
293 if hasattr(objects, 'items'):
294 objects = (value for key, value in objects.items()
295 if not str(key).startswith('_')
296 and not inspect.ismodule(value) )
297
298 # try if objects is iterable
299 try:
300 for obj in (obj for obj in objects if id(obj) not in visited):
301 # add current object to visited to avoid revisiting
302 visited.update({id(obj):obj})
303
304 # update, if object is instance of old_class (but no subclasses)
305 if type(obj) is old:
306 obj.__class__ = new
307
308
309 # if object is instance of other class, look for nested instances
310 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
311 or inspect.ismethod(obj)):
312 update_instances(old, new, obj.__dict__, visited)
313
314 # if object is a container, search it
315 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
316 and not isinstance(obj, str)):
317 update_instances(old, new, obj, visited)
318
319 except TypeError:
320 pass
321
322
270 def update_class(old, new):
323 def update_class(old, new):
271 """Replace stuff in the __dict__ of a class, and upgrade
324 """Replace stuff in the __dict__ of a class, and upgrade
272 method code objects, and add new methods, if any"""
325 method code objects, and add new methods, if any"""
273 for key in list(old.__dict__.keys()):
326 for key in list(old.__dict__.keys()):
274 old_obj = getattr(old, key)
327 old_obj = getattr(old, key)
275 try:
328 try:
276 new_obj = getattr(new, key)
329 new_obj = getattr(new, key)
277 # explicitly checking that comparison returns True to handle
330 # explicitly checking that comparison returns True to handle
278 # cases where `==` doesn't return a boolean.
331 # cases where `==` doesn't return a boolean.
279 if (old_obj == new_obj) is True:
332 if (old_obj == new_obj) is True:
280 continue
333 continue
281 except AttributeError:
334 except AttributeError:
282 # obsolete attribute: remove it
335 # obsolete attribute: remove it
283 try:
336 try:
284 delattr(old, key)
337 delattr(old, key)
285 except (AttributeError, TypeError):
338 except (AttributeError, TypeError):
286 pass
339 pass
287 continue
340 continue
288
341
289 if update_generic(old_obj, new_obj): continue
342 if update_generic(old_obj, new_obj): continue
290
343
291 try:
344 try:
292 setattr(old, key, getattr(new, key))
345 setattr(old, key, getattr(new, key))
293 except (AttributeError, TypeError):
346 except (AttributeError, TypeError):
294 pass # skip non-writable attributes
347 pass # skip non-writable attributes
295
348
296 for key in list(new.__dict__.keys()):
349 for key in list(new.__dict__.keys()):
297 if key not in list(old.__dict__.keys()):
350 if key not in list(old.__dict__.keys()):
298 try:
351 try:
299 setattr(old, key, getattr(new, key))
352 setattr(old, key, getattr(new, key))
300 except (AttributeError, TypeError):
353 except (AttributeError, TypeError):
301 pass # skip non-writable attributes
354 pass # skip non-writable attributes
302
355
356 # update all instances of class
357 update_instances(old, new)
358
303
359
304 def update_property(old, new):
360 def update_property(old, new):
305 """Replace get/set/del functions of a property"""
361 """Replace get/set/del functions of a property"""
306 update_generic(old.fdel, new.fdel)
362 update_generic(old.fdel, new.fdel)
307 update_generic(old.fget, new.fget)
363 update_generic(old.fget, new.fget)
308 update_generic(old.fset, new.fset)
364 update_generic(old.fset, new.fset)
309
365
310
366
311 def isinstance2(a, b, typ):
367 def isinstance2(a, b, typ):
312 return isinstance(a, typ) and isinstance(b, typ)
368 return isinstance(a, typ) and isinstance(b, typ)
313
369
314
370
315 UPDATE_RULES = [
371 UPDATE_RULES = [
316 (lambda a, b: isinstance2(a, b, type),
372 (lambda a, b: isinstance2(a, b, type),
317 update_class),
373 update_class),
318 (lambda a, b: isinstance2(a, b, types.FunctionType),
374 (lambda a, b: isinstance2(a, b, types.FunctionType),
319 update_function),
375 update_function),
320 (lambda a, b: isinstance2(a, b, property),
376 (lambda a, b: isinstance2(a, b, property),
321 update_property),
377 update_property),
322 ]
378 ]
323 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
379 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
324 lambda a, b: update_function(a.__func__, b.__func__)),
380 lambda a, b: update_function(a.__func__, b.__func__)),
325 ])
381 ])
326
382
327
383
328 def update_generic(a, b):
384 def update_generic(a, b):
329 for type_check, update in UPDATE_RULES:
385 for type_check, update in UPDATE_RULES:
330 if type_check(a, b):
386 if type_check(a, b):
331 update(a, b)
387 update(a, b)
332 return True
388 return True
333 return False
389 return False
334
390
335
391
336 class StrongRef(object):
392 class StrongRef(object):
337 def __init__(self, obj):
393 def __init__(self, obj):
338 self.obj = obj
394 self.obj = obj
339 def __call__(self):
395 def __call__(self):
340 return self.obj
396 return self.obj
341
397
342
398
343 def superreload(module, reload=reload, old_objects=None):
399 def superreload(module, reload=reload, old_objects=None):
344 """Enhanced version of the builtin reload function.
400 """Enhanced version of the builtin reload function.
345
401
346 superreload remembers objects previously in the module, and
402 superreload remembers objects previously in the module, and
347
403
348 - upgrades the class dictionary of every old class in the module
404 - upgrades the class dictionary of every old class in the module
349 - upgrades the code object of every old function and method
405 - upgrades the code object of every old function and method
350 - clears the module's namespace before reloading
406 - clears the module's namespace before reloading
351
407
352 """
408 """
353 if old_objects is None:
409 if old_objects is None:
354 old_objects = {}
410 old_objects = {}
355
411
356 # collect old objects in the module
412 # collect old objects in the module
357 for name, obj in list(module.__dict__.items()):
413 for name, obj in list(module.__dict__.items()):
358 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
414 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
359 continue
415 continue
360 key = (module.__name__, name)
416 key = (module.__name__, name)
361 try:
417 try:
362 old_objects.setdefault(key, []).append(weakref.ref(obj))
418 old_objects.setdefault(key, []).append(weakref.ref(obj))
363 except TypeError:
419 except TypeError:
364 pass
420 pass
365
421
366 # reload module
422 # reload module
367 try:
423 try:
368 # clear namespace first from old cruft
424 # clear namespace first from old cruft
369 old_dict = module.__dict__.copy()
425 old_dict = module.__dict__.copy()
370 old_name = module.__name__
426 old_name = module.__name__
371 module.__dict__.clear()
427 module.__dict__.clear()
372 module.__dict__['__name__'] = old_name
428 module.__dict__['__name__'] = old_name
373 module.__dict__['__loader__'] = old_dict['__loader__']
429 module.__dict__['__loader__'] = old_dict['__loader__']
374 except (TypeError, AttributeError, KeyError):
430 except (TypeError, AttributeError, KeyError):
375 pass
431 pass
376
432
377 try:
433 try:
378 module = reload(module)
434 module = reload(module)
379 except:
435 except:
380 # restore module dictionary on failed reload
436 # restore module dictionary on failed reload
381 module.__dict__.update(old_dict)
437 module.__dict__.update(old_dict)
382 raise
438 raise
383
439
384 # iterate over all objects and update functions & classes
440 # iterate over all objects and update functions & classes
385 for name, new_obj in list(module.__dict__.items()):
441 for name, new_obj in list(module.__dict__.items()):
386 key = (module.__name__, name)
442 key = (module.__name__, name)
387 if key not in old_objects: continue
443 if key not in old_objects: continue
388
444
389 new_refs = []
445 new_refs = []
390 for old_ref in old_objects[key]:
446 for old_ref in old_objects[key]:
391 old_obj = old_ref()
447 old_obj = old_ref()
392 if old_obj is None: continue
448 if old_obj is None: continue
393 new_refs.append(old_ref)
449 new_refs.append(old_ref)
394 update_generic(old_obj, new_obj)
450 update_generic(old_obj, new_obj)
395
451
396 if new_refs:
452 if new_refs:
397 old_objects[key] = new_refs
453 old_objects[key] = new_refs
398 else:
454 else:
399 del old_objects[key]
455 del old_objects[key]
400
456
401 return module
457 return module
402
458
403 #------------------------------------------------------------------------------
459 #------------------------------------------------------------------------------
404 # IPython connectivity
460 # IPython connectivity
405 #------------------------------------------------------------------------------
461 #------------------------------------------------------------------------------
406
462
407 from IPython.core.magic import Magics, magics_class, line_magic
463 from IPython.core.magic import Magics, magics_class, line_magic
408
464
409 @magics_class
465 @magics_class
410 class AutoreloadMagics(Magics):
466 class AutoreloadMagics(Magics):
411 def __init__(self, *a, **kw):
467 def __init__(self, *a, **kw):
412 super(AutoreloadMagics, self).__init__(*a, **kw)
468 super(AutoreloadMagics, self).__init__(*a, **kw)
413 self._reloader = ModuleReloader()
469 self._reloader = ModuleReloader()
414 self._reloader.check_all = False
470 self._reloader.check_all = False
415 self.loaded_modules = set(sys.modules)
471 self.loaded_modules = set(sys.modules)
416
472
417 @line_magic
473 @line_magic
418 def autoreload(self, parameter_s=''):
474 def autoreload(self, parameter_s=''):
419 r"""%autoreload => Reload modules automatically
475 r"""%autoreload => Reload modules automatically
420
476
421 %autoreload
477 %autoreload
422 Reload all modules (except those excluded by %aimport) automatically
478 Reload all modules (except those excluded by %aimport) automatically
423 now.
479 now.
424
480
425 %autoreload 0
481 %autoreload 0
426 Disable automatic reloading.
482 Disable automatic reloading.
427
483
428 %autoreload 1
484 %autoreload 1
429 Reload all modules imported with %aimport every time before executing
485 Reload all modules imported with %aimport every time before executing
430 the Python code typed.
486 the Python code typed.
431
487
432 %autoreload 2
488 %autoreload 2
433 Reload all modules (except those excluded by %aimport) every time
489 Reload all modules (except those excluded by %aimport) every time
434 before executing the Python code typed.
490 before executing the Python code typed.
435
491
436 Reloading Python modules in a reliable way is in general
492 Reloading Python modules in a reliable way is in general
437 difficult, and unexpected things may occur. %autoreload tries to
493 difficult, and unexpected things may occur. %autoreload tries to
438 work around common pitfalls by replacing function code objects and
494 work around common pitfalls by replacing function code objects and
439 parts of classes previously in the module with new versions. This
495 parts of classes previously in the module with new versions. This
440 makes the following things to work:
496 makes the following things to work:
441
497
442 - Functions and classes imported via 'from xxx import foo' are upgraded
498 - Functions and classes imported via 'from xxx import foo' are upgraded
443 to new versions when 'xxx' is reloaded.
499 to new versions when 'xxx' is reloaded.
444
500
445 - Methods and properties of classes are upgraded on reload, so that
501 - Methods and properties of classes are upgraded on reload, so that
446 calling 'c.foo()' on an object 'c' created before the reload causes
502 calling 'c.foo()' on an object 'c' created before the reload causes
447 the new code for 'foo' to be executed.
503 the new code for 'foo' to be executed.
448
504
449 Some of the known remaining caveats are:
505 Some of the known remaining caveats are:
450
506
451 - Replacing code objects does not always succeed: changing a @property
507 - Replacing code objects does not always succeed: changing a @property
452 in a class to an ordinary method or a method to a member variable
508 in a class to an ordinary method or a method to a member variable
453 can cause problems (but in old objects only).
509 can cause problems (but in old objects only).
454
510
455 - Functions that are removed (eg. via monkey-patching) from a module
511 - Functions that are removed (eg. via monkey-patching) from a module
456 before it is reloaded are not upgraded.
512 before it is reloaded are not upgraded.
457
513
458 - C extension modules cannot be reloaded, and so cannot be
514 - C extension modules cannot be reloaded, and so cannot be
459 autoreloaded.
515 autoreloaded.
460
516
461 """
517 """
462 if parameter_s == '':
518 if parameter_s == '':
463 self._reloader.check(True)
519 self._reloader.check(True)
464 elif parameter_s == '0':
520 elif parameter_s == '0':
465 self._reloader.enabled = False
521 self._reloader.enabled = False
466 elif parameter_s == '1':
522 elif parameter_s == '1':
467 self._reloader.check_all = False
523 self._reloader.check_all = False
468 self._reloader.enabled = True
524 self._reloader.enabled = True
469 elif parameter_s == '2':
525 elif parameter_s == '2':
470 self._reloader.check_all = True
526 self._reloader.check_all = True
471 self._reloader.enabled = True
527 self._reloader.enabled = True
472
528
473 @line_magic
529 @line_magic
474 def aimport(self, parameter_s='', stream=None):
530 def aimport(self, parameter_s='', stream=None):
475 """%aimport => Import modules for automatic reloading.
531 """%aimport => Import modules for automatic reloading.
476
532
477 %aimport
533 %aimport
478 List modules to automatically import and not to import.
534 List modules to automatically import and not to import.
479
535
480 %aimport foo
536 %aimport foo
481 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
537 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
482
538
483 %aimport foo, bar
539 %aimport foo, bar
484 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
540 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
485
541
486 %aimport -foo
542 %aimport -foo
487 Mark module 'foo' to not be autoreloaded for %autoreload 1
543 Mark module 'foo' to not be autoreloaded for %autoreload 1
488 """
544 """
489 modname = parameter_s
545 modname = parameter_s
490 if not modname:
546 if not modname:
491 to_reload = sorted(self._reloader.modules.keys())
547 to_reload = sorted(self._reloader.modules.keys())
492 to_skip = sorted(self._reloader.skip_modules.keys())
548 to_skip = sorted(self._reloader.skip_modules.keys())
493 if stream is None:
549 if stream is None:
494 stream = sys.stdout
550 stream = sys.stdout
495 if self._reloader.check_all:
551 if self._reloader.check_all:
496 stream.write("Modules to reload:\nall-except-skipped\n")
552 stream.write("Modules to reload:\nall-except-skipped\n")
497 else:
553 else:
498 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
554 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
499 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
555 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
500 elif modname.startswith('-'):
556 elif modname.startswith('-'):
501 modname = modname[1:]
557 modname = modname[1:]
502 self._reloader.mark_module_skipped(modname)
558 self._reloader.mark_module_skipped(modname)
503 else:
559 else:
504 for _module in ([_.strip() for _ in modname.split(',')]):
560 for _module in ([_.strip() for _ in modname.split(',')]):
505 top_module, top_name = self._reloader.aimport_module(_module)
561 top_module, top_name = self._reloader.aimport_module(_module)
506
562
507 # Inject module to user namespace
563 # Inject module to user namespace
508 self.shell.push({top_name: top_module})
564 self.shell.push({top_name: top_module})
509
565
510 def pre_run_cell(self):
566 def pre_run_cell(self):
511 if self._reloader.enabled:
567 if self._reloader.enabled:
512 try:
568 try:
513 self._reloader.check()
569 self._reloader.check()
514 except:
570 except:
515 pass
571 pass
516
572
517 def post_execute_hook(self):
573 def post_execute_hook(self):
518 """Cache the modification times of any modules imported in this execution
574 """Cache the modification times of any modules imported in this execution
519 """
575 """
520 newly_loaded_modules = set(sys.modules) - self.loaded_modules
576 newly_loaded_modules = set(sys.modules) - self.loaded_modules
521 for modname in newly_loaded_modules:
577 for modname in newly_loaded_modules:
522 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
578 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
523 if pymtime is not None:
579 if pymtime is not None:
524 self._reloader.modules_mtimes[modname] = pymtime
580 self._reloader.modules_mtimes[modname] = pymtime
525
581
526 self.loaded_modules.update(newly_loaded_modules)
582 self.loaded_modules.update(newly_loaded_modules)
527
583
528
584
529 def load_ipython_extension(ip):
585 def load_ipython_extension(ip):
530 """Load the extension in IPython."""
586 """Load the extension in IPython."""
531 auto_reload = AutoreloadMagics(ip)
587 auto_reload = AutoreloadMagics(ip)
532 ip.register_magics(auto_reload)
588 ip.register_magics(auto_reload)
533 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
589 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
534 ip.events.register('post_execute', auto_reload.post_execute_hook)
590 ip.events.register('post_execute', auto_reload.post_execute_hook)
@@ -1,398 +1,449 b''
1 """Tests for autoreload extension.
1 """Tests for autoreload extension.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 IPython Development Team.
4 # Copyright (c) 2012 IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import textwrap
18 import textwrap
19 import shutil
19 import shutil
20 import random
20 import random
21 import time
21 import time
22 from io import StringIO
22 from io import StringIO
23
23
24 import nose.tools as nt
24 import nose.tools as nt
25 import IPython.testing.tools as tt
25 import IPython.testing.tools as tt
26
26
27 from IPython.testing.decorators import skipif
27 from IPython.testing.decorators import skipif
28
28
29 from IPython.extensions.autoreload import AutoreloadMagics
29 from IPython.extensions.autoreload import AutoreloadMagics
30 from IPython.core.events import EventManager, pre_run_cell
30 from IPython.core.events import EventManager, pre_run_cell
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Test fixture
33 # Test fixture
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 noop = lambda *a, **kw: None
36 noop = lambda *a, **kw: None
37
37
38 class FakeShell(object):
38 class FakeShell:
39
39
40 def __init__(self):
40 def __init__(self):
41 self.ns = {}
41 self.ns = {}
42 self.user_ns = self.ns
43 self.user_ns_hidden = {}
42 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
44 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
43 self.auto_magics = AutoreloadMagics(shell=self)
45 self.auto_magics = AutoreloadMagics(shell=self)
44 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
46 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
45
47
46 register_magics = set_hook = noop
48 register_magics = set_hook = noop
47
49
48 def run_code(self, code):
50 def run_code(self, code):
49 self.events.trigger('pre_run_cell')
51 self.events.trigger('pre_run_cell')
50 exec(code, self.ns)
52 exec(code, self.user_ns)
51 self.auto_magics.post_execute_hook()
53 self.auto_magics.post_execute_hook()
52
54
53 def push(self, items):
55 def push(self, items):
54 self.ns.update(items)
56 self.ns.update(items)
55
57
56 def magic_autoreload(self, parameter):
58 def magic_autoreload(self, parameter):
57 self.auto_magics.autoreload(parameter)
59 self.auto_magics.autoreload(parameter)
58
60
59 def magic_aimport(self, parameter, stream=None):
61 def magic_aimport(self, parameter, stream=None):
60 self.auto_magics.aimport(parameter, stream=stream)
62 self.auto_magics.aimport(parameter, stream=stream)
61 self.auto_magics.post_execute_hook()
63 self.auto_magics.post_execute_hook()
62
64
63
65
64 class Fixture(object):
66 class Fixture(object):
65 """Fixture for creating test module files"""
67 """Fixture for creating test module files"""
66
68
67 test_dir = None
69 test_dir = None
68 old_sys_path = None
70 old_sys_path = None
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
71 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
70
72
71 def setUp(self):
73 def setUp(self):
72 self.test_dir = tempfile.mkdtemp()
74 self.test_dir = tempfile.mkdtemp()
73 self.old_sys_path = list(sys.path)
75 self.old_sys_path = list(sys.path)
74 sys.path.insert(0, self.test_dir)
76 sys.path.insert(0, self.test_dir)
75 self.shell = FakeShell()
77 self.shell = FakeShell()
76
78
77 def tearDown(self):
79 def tearDown(self):
78 shutil.rmtree(self.test_dir)
80 shutil.rmtree(self.test_dir)
79 sys.path = self.old_sys_path
81 sys.path = self.old_sys_path
80
82
81 self.test_dir = None
83 self.test_dir = None
82 self.old_sys_path = None
84 self.old_sys_path = None
83 self.shell = None
85 self.shell = None
84
86
85 def get_module(self):
87 def get_module(self):
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
88 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
87 if module_name in sys.modules:
89 if module_name in sys.modules:
88 del sys.modules[module_name]
90 del sys.modules[module_name]
89 file_name = os.path.join(self.test_dir, module_name + ".py")
91 file_name = os.path.join(self.test_dir, module_name + ".py")
90 return module_name, file_name
92 return module_name, file_name
91
93
92 def write_file(self, filename, content):
94 def write_file(self, filename, content):
93 """
95 """
94 Write a file, and force a timestamp difference of at least one second
96 Write a file, and force a timestamp difference of at least one second
95
97
96 Notes
98 Notes
97 -----
99 -----
98 Python's .pyc files record the timestamp of their compilation
100 Python's .pyc files record the timestamp of their compilation
99 with a time resolution of one second.
101 with a time resolution of one second.
100
102
101 Therefore, we need to force a timestamp difference between .py
103 Therefore, we need to force a timestamp difference between .py
102 and .pyc, without having the .py file be timestamped in the
104 and .pyc, without having the .py file be timestamped in the
103 future, and without changing the timestamp of the .pyc file
105 future, and without changing the timestamp of the .pyc file
104 (because that is stored in the file). The only reliable way
106 (because that is stored in the file). The only reliable way
105 to achieve this seems to be to sleep.
107 to achieve this seems to be to sleep.
106 """
108 """
107
109 content = textwrap.dedent(content)
108 # Sleep one second + eps
110 # Sleep one second + eps
109 time.sleep(1.05)
111 time.sleep(1.05)
110
112
111 # Write
113 # Write
112 with open(filename, 'w') as f:
114 with open(filename, 'w') as f:
113 f.write(content)
115 f.write(content)
114
116
115 def new_module(self, code):
117 def new_module(self, code):
118 code = textwrap.dedent(code)
116 mod_name, mod_fn = self.get_module()
119 mod_name, mod_fn = self.get_module()
117 with open(mod_fn, 'w') as f:
120 with open(mod_fn, 'w') as f:
118 f.write(code)
121 f.write(code)
119 return mod_name, mod_fn
122 return mod_name, mod_fn
120
123
121 #-----------------------------------------------------------------------------
124 #-----------------------------------------------------------------------------
122 # Test automatic reloading
125 # Test automatic reloading
123 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
124
127
128 def pickle_get_current_class(obj):
129 """
130 Original issue comes from pickle; hence the name.
131 """
132 name = obj.__class__.__name__
133 module_name = getattr(obj, "__module__", None)
134 obj2 = sys.modules[module_name]
135 for subpath in name.split("."):
136 obj2 = getattr(obj2, subpath)
137 return obj2
138
125 class TestAutoreload(Fixture):
139 class TestAutoreload(Fixture):
126
140
127 @skipif(sys.version_info < (3, 6))
141 @skipif(sys.version_info < (3, 6))
128 def test_reload_enums(self):
142 def test_reload_enums(self):
129 import enum
143 import enum
130 mod_name, mod_fn = self.new_module(textwrap.dedent("""
144 mod_name, mod_fn = self.new_module(textwrap.dedent("""
131 from enum import Enum
145 from enum import Enum
132 class MyEnum(Enum):
146 class MyEnum(Enum):
133 A = 'A'
147 A = 'A'
134 B = 'B'
148 B = 'B'
135 """))
149 """))
136 self.shell.magic_autoreload("2")
150 self.shell.magic_autoreload("2")
137 self.shell.magic_aimport(mod_name)
151 self.shell.magic_aimport(mod_name)
138 self.write_file(mod_fn, textwrap.dedent("""
152 self.write_file(mod_fn, textwrap.dedent("""
139 from enum import Enum
153 from enum import Enum
140 class MyEnum(Enum):
154 class MyEnum(Enum):
141 A = 'A'
155 A = 'A'
142 B = 'B'
156 B = 'B'
143 C = 'C'
157 C = 'C'
144 """))
158 """))
145 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
159 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
146 self.shell.run_code("pass") # trigger another reload
160 self.shell.run_code("pass") # trigger another reload
147
161
162 def test_reload_class_type(self):
163 self.shell.magic_autoreload("2")
164 mod_name, mod_fn = self.new_module(
165 """
166 class Test():
167 def meth(self):
168 return "old"
169 """
170 )
171 assert "test" not in self.shell.ns
172 assert "result" not in self.shell.ns
173
174 self.shell.run_code("from %s import Test" % mod_name)
175 self.shell.run_code("test = Test()")
176
177 self.write_file(
178 mod_fn,
179 """
180 class Test():
181 def meth(self):
182 return "new"
183 """,
184 )
185
186 test_object = self.shell.ns["test"]
187
188 # important to trigger autoreload logic !
189 self.shell.run_code("pass")
190
191 test_class = pickle_get_current_class(test_object)
192 assert isinstance(test_object, test_class)
193
194 # extra check.
195 self.shell.run_code("import pickle")
196 self.shell.run_code("p = pickle.dumps(test)")
197
148 def test_reload_class_attributes(self):
198 def test_reload_class_attributes(self):
149 self.shell.magic_autoreload("2")
199 self.shell.magic_autoreload("2")
150 mod_name, mod_fn = self.new_module(textwrap.dedent("""
200 mod_name, mod_fn = self.new_module(textwrap.dedent("""
151 class MyClass:
201 class MyClass:
152
202
153 def __init__(self, a=10):
203 def __init__(self, a=10):
154 self.a = a
204 self.a = a
155 self.b = 22
205 self.b = 22
156 # self.toto = 33
206 # self.toto = 33
157
207
158 def square(self):
208 def square(self):
159 print('compute square')
209 print('compute square')
160 return self.a*self.a
210 return self.a*self.a
161 """
211 """
162 )
212 )
163 )
213 )
164 self.shell.run_code("from %s import MyClass" % mod_name)
214 self.shell.run_code("from %s import MyClass" % mod_name)
165 self.shell.run_code("first = MyClass(5)")
215 self.shell.run_code("first = MyClass(5)")
166 self.shell.run_code("first.square()")
216 self.shell.run_code("first.square()")
167 with nt.assert_raises(AttributeError):
217 with nt.assert_raises(AttributeError):
168 self.shell.run_code("first.cube()")
218 self.shell.run_code("first.cube()")
169 with nt.assert_raises(AttributeError):
219 with nt.assert_raises(AttributeError):
170 self.shell.run_code("first.power(5)")
220 self.shell.run_code("first.power(5)")
171 self.shell.run_code("first.b")
221 self.shell.run_code("first.b")
172 with nt.assert_raises(AttributeError):
222 with nt.assert_raises(AttributeError):
173 self.shell.run_code("first.toto")
223 self.shell.run_code("first.toto")
174
224
175 # remove square, add power
225 # remove square, add power
176
226
177 self.write_file(
227 self.write_file(
178 mod_fn,
228 mod_fn,
179 textwrap.dedent(
229 textwrap.dedent(
180 """
230 """
181 class MyClass:
231 class MyClass:
182
232
183 def __init__(self, a=10):
233 def __init__(self, a=10):
184 self.a = a
234 self.a = a
185 self.b = 11
235 self.b = 11
186
236
187 def power(self, p):
237 def power(self, p):
188 print('compute power '+str(p))
238 print('compute power '+str(p))
189 return self.a**p
239 return self.a**p
190 """
240 """
191 ),
241 ),
192 )
242 )
193
243
194 self.shell.run_code("second = MyClass(5)")
244 self.shell.run_code("second = MyClass(5)")
195
245
196 for object_name in {'first', 'second'}:
246 for object_name in {'first', 'second'}:
197 self.shell.run_code("{object_name}.power(5)".format(object_name=object_name))
247 self.shell.run_code("{object_name}.power(5)".format(object_name=object_name))
198 with nt.assert_raises(AttributeError):
248 with nt.assert_raises(AttributeError):
199 self.shell.run_code("{object_name}.cube()".format(object_name=object_name))
249 self.shell.run_code("{object_name}.cube()".format(object_name=object_name))
200 with nt.assert_raises(AttributeError):
250 with nt.assert_raises(AttributeError):
201 self.shell.run_code("{object_name}.square()".format(object_name=object_name))
251 self.shell.run_code("{object_name}.square()".format(object_name=object_name))
202 self.shell.run_code("{object_name}.b".format(object_name=object_name))
252 self.shell.run_code("{object_name}.b".format(object_name=object_name))
203 self.shell.run_code("{object_name}.a".format(object_name=object_name))
253 self.shell.run_code("{object_name}.a".format(object_name=object_name))
204 with nt.assert_raises(AttributeError):
254 with nt.assert_raises(AttributeError):
205 self.shell.run_code("{object_name}.toto".format(object_name=object_name))
255 self.shell.run_code("{object_name}.toto".format(object_name=object_name))
206
256
207 def _check_smoketest(self, use_aimport=True):
257 def _check_smoketest(self, use_aimport=True):
208 """
258 """
209 Functional test for the automatic reloader using either
259 Functional test for the automatic reloader using either
210 '%autoreload 1' or '%autoreload 2'
260 '%autoreload 1' or '%autoreload 2'
211 """
261 """
212
262
213 mod_name, mod_fn = self.new_module("""
263 mod_name, mod_fn = self.new_module("""
214 x = 9
264 x = 9
215
265
216 z = 123 # this item will be deleted
266 z = 123 # this item will be deleted
217
267
218 def foo(y):
268 def foo(y):
219 return y + 3
269 return y + 3
220
270
221 class Baz(object):
271 class Baz(object):
222 def __init__(self, x):
272 def __init__(self, x):
223 self.x = x
273 self.x = x
224 def bar(self, y):
274 def bar(self, y):
225 return self.x + y
275 return self.x + y
226 @property
276 @property
227 def quux(self):
277 def quux(self):
228 return 42
278 return 42
229 def zzz(self):
279 def zzz(self):
230 '''This method will be deleted below'''
280 '''This method will be deleted below'''
231 return 99
281 return 99
232
282
233 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
283 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
234 def foo(self):
284 def foo(self):
235 return 1
285 return 1
236 """)
286 """)
237
287
238 #
288 #
239 # Import module, and mark for reloading
289 # Import module, and mark for reloading
240 #
290 #
241 if use_aimport:
291 if use_aimport:
242 self.shell.magic_autoreload("1")
292 self.shell.magic_autoreload("1")
243 self.shell.magic_aimport(mod_name)
293 self.shell.magic_aimport(mod_name)
244 stream = StringIO()
294 stream = StringIO()
245 self.shell.magic_aimport("", stream=stream)
295 self.shell.magic_aimport("", stream=stream)
246 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
296 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
247
297
248 with nt.assert_raises(ImportError):
298 with nt.assert_raises(ImportError):
249 self.shell.magic_aimport("tmpmod_as318989e89ds")
299 self.shell.magic_aimport("tmpmod_as318989e89ds")
250 else:
300 else:
251 self.shell.magic_autoreload("2")
301 self.shell.magic_autoreload("2")
252 self.shell.run_code("import %s" % mod_name)
302 self.shell.run_code("import %s" % mod_name)
253 stream = StringIO()
303 stream = StringIO()
254 self.shell.magic_aimport("", stream=stream)
304 self.shell.magic_aimport("", stream=stream)
255 nt.assert_true("Modules to reload:\nall-except-skipped" in
305 nt.assert_true("Modules to reload:\nall-except-skipped" in
256 stream.getvalue())
306 stream.getvalue())
257 nt.assert_in(mod_name, self.shell.ns)
307 nt.assert_in(mod_name, self.shell.ns)
258
308
259 mod = sys.modules[mod_name]
309 mod = sys.modules[mod_name]
260
310
261 #
311 #
262 # Test module contents
312 # Test module contents
263 #
313 #
264 old_foo = mod.foo
314 old_foo = mod.foo
265 old_obj = mod.Baz(9)
315 old_obj = mod.Baz(9)
266 old_obj2 = mod.Bar()
316 old_obj2 = mod.Bar()
267
317
268 def check_module_contents():
318 def check_module_contents():
269 nt.assert_equal(mod.x, 9)
319 nt.assert_equal(mod.x, 9)
270 nt.assert_equal(mod.z, 123)
320 nt.assert_equal(mod.z, 123)
271
321
272 nt.assert_equal(old_foo(0), 3)
322 nt.assert_equal(old_foo(0), 3)
273 nt.assert_equal(mod.foo(0), 3)
323 nt.assert_equal(mod.foo(0), 3)
274
324
275 obj = mod.Baz(9)
325 obj = mod.Baz(9)
276 nt.assert_equal(old_obj.bar(1), 10)
326 nt.assert_equal(old_obj.bar(1), 10)
277 nt.assert_equal(obj.bar(1), 10)
327 nt.assert_equal(obj.bar(1), 10)
278 nt.assert_equal(obj.quux, 42)
328 nt.assert_equal(obj.quux, 42)
279 nt.assert_equal(obj.zzz(), 99)
329 nt.assert_equal(obj.zzz(), 99)
280
330
281 obj2 = mod.Bar()
331 obj2 = mod.Bar()
282 nt.assert_equal(old_obj2.foo(), 1)
332 nt.assert_equal(old_obj2.foo(), 1)
283 nt.assert_equal(obj2.foo(), 1)
333 nt.assert_equal(obj2.foo(), 1)
284
334
285 check_module_contents()
335 check_module_contents()
286
336
287 #
337 #
288 # Simulate a failed reload: no reload should occur and exactly
338 # Simulate a failed reload: no reload should occur and exactly
289 # one error message should be printed
339 # one error message should be printed
290 #
340 #
291 self.write_file(mod_fn, """
341 self.write_file(mod_fn, """
292 a syntax error
342 a syntax error
293 """)
343 """)
294
344
295 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
345 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
296 self.shell.run_code("pass") # trigger reload
346 self.shell.run_code("pass") # trigger reload
297 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
347 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
298 self.shell.run_code("pass") # trigger another reload
348 self.shell.run_code("pass") # trigger another reload
299 check_module_contents()
349 check_module_contents()
300
350
301 #
351 #
302 # Rewrite module (this time reload should succeed)
352 # Rewrite module (this time reload should succeed)
303 #
353 #
304 self.write_file(mod_fn, """
354 self.write_file(mod_fn, """
305 x = 10
355 x = 10
306
356
307 def foo(y):
357 def foo(y):
308 return y + 4
358 return y + 4
309
359
310 class Baz(object):
360 class Baz(object):
311 def __init__(self, x):
361 def __init__(self, x):
312 self.x = x
362 self.x = x
313 def bar(self, y):
363 def bar(self, y):
314 return self.x + y + 1
364 return self.x + y + 1
315 @property
365 @property
316 def quux(self):
366 def quux(self):
317 return 43
367 return 43
318
368
319 class Bar: # old-style class
369 class Bar: # old-style class
320 def foo(self):
370 def foo(self):
321 return 2
371 return 2
322 """)
372 """)
323
373
324 def check_module_contents():
374 def check_module_contents():
325 nt.assert_equal(mod.x, 10)
375 nt.assert_equal(mod.x, 10)
326 nt.assert_false(hasattr(mod, 'z'))
376 nt.assert_false(hasattr(mod, 'z'))
327
377
328 nt.assert_equal(old_foo(0), 4) # superreload magic!
378 nt.assert_equal(old_foo(0), 4) # superreload magic!
329 nt.assert_equal(mod.foo(0), 4)
379 nt.assert_equal(mod.foo(0), 4)
330
380
331 obj = mod.Baz(9)
381 obj = mod.Baz(9)
332 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
382 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
333 nt.assert_equal(obj.bar(1), 11)
383 nt.assert_equal(obj.bar(1), 11)
334
384
335 nt.assert_equal(old_obj.quux, 43)
385 nt.assert_equal(old_obj.quux, 43)
336 nt.assert_equal(obj.quux, 43)
386 nt.assert_equal(obj.quux, 43)
337
387
338 nt.assert_false(hasattr(old_obj, 'zzz'))
388 nt.assert_false(hasattr(old_obj, 'zzz'))
339 nt.assert_false(hasattr(obj, 'zzz'))
389 nt.assert_false(hasattr(obj, 'zzz'))
340
390
341 obj2 = mod.Bar()
391 obj2 = mod.Bar()
342 nt.assert_equal(old_obj2.foo(), 2)
392 nt.assert_equal(old_obj2.foo(), 2)
343 nt.assert_equal(obj2.foo(), 2)
393 nt.assert_equal(obj2.foo(), 2)
344
394
345 self.shell.run_code("pass") # trigger reload
395 self.shell.run_code("pass") # trigger reload
346 check_module_contents()
396 check_module_contents()
347
397
348 #
398 #
349 # Another failure case: deleted file (shouldn't reload)
399 # Another failure case: deleted file (shouldn't reload)
350 #
400 #
351 os.unlink(mod_fn)
401 os.unlink(mod_fn)
352
402
353 self.shell.run_code("pass") # trigger reload
403 self.shell.run_code("pass") # trigger reload
354 check_module_contents()
404 check_module_contents()
355
405
356 #
406 #
357 # Disable autoreload and rewrite module: no reload should occur
407 # Disable autoreload and rewrite module: no reload should occur
358 #
408 #
359 if use_aimport:
409 if use_aimport:
360 self.shell.magic_aimport("-" + mod_name)
410 self.shell.magic_aimport("-" + mod_name)
361 stream = StringIO()
411 stream = StringIO()
362 self.shell.magic_aimport("", stream=stream)
412 self.shell.magic_aimport("", stream=stream)
363 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
413 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
364 stream.getvalue())
414 stream.getvalue())
365
415
366 # This should succeed, although no such module exists
416 # This should succeed, although no such module exists
367 self.shell.magic_aimport("-tmpmod_as318989e89ds")
417 self.shell.magic_aimport("-tmpmod_as318989e89ds")
368 else:
418 else:
369 self.shell.magic_autoreload("0")
419 self.shell.magic_autoreload("0")
370
420
371 self.write_file(mod_fn, """
421 self.write_file(mod_fn, """
372 x = -99
422 x = -99
373 """)
423 """)
374
424
375 self.shell.run_code("pass") # trigger reload
425 self.shell.run_code("pass") # trigger reload
376 self.shell.run_code("pass")
426 self.shell.run_code("pass")
377 check_module_contents()
427 check_module_contents()
378
428
379 #
429 #
380 # Re-enable autoreload: reload should now occur
430 # Re-enable autoreload: reload should now occur
381 #
431 #
382 if use_aimport:
432 if use_aimport:
383 self.shell.magic_aimport(mod_name)
433 self.shell.magic_aimport(mod_name)
384 else:
434 else:
385 self.shell.magic_autoreload("")
435 self.shell.magic_autoreload("")
386
436
387 self.shell.run_code("pass") # trigger reload
437 self.shell.run_code("pass") # trigger reload
388 nt.assert_equal(mod.x, -99)
438 nt.assert_equal(mod.x, -99)
389
439
390 def test_smoketest_aimport(self):
440 def test_smoketest_aimport(self):
391 self._check_smoketest(use_aimport=True)
441 self._check_smoketest(use_aimport=True)
392
442
393 def test_smoketest_autoreload(self):
443 def test_smoketest_autoreload(self):
394 self._check_smoketest(use_aimport=False)
444 self._check_smoketest(use_aimport=False)
395
445
396
446
397
447
398
448
449
General Comments 0
You need to be logged in to leave comments. Login now