##// END OF EJS Templates
Try to fix updating classes in Autoreload....
Matthias Bussonnier -
Show More
@@ -1,586 +1,642 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 import inspect
118 import inspect
119 from importlib import import_module
119 from importlib import import_module
120 from importlib.util import source_from_cache
120 from importlib.util import source_from_cache
121 from imp import reload
121 from imp import reload
122
122
123 #------------------------------------------------------------------------------
123 #------------------------------------------------------------------------------
124 # Autoreload functionality
124 # Autoreload functionality
125 #------------------------------------------------------------------------------
125 #------------------------------------------------------------------------------
126
126
127 class ModuleReloader(object):
127 class ModuleReloader(object):
128 enabled = False
128 enabled = False
129 """Whether this reloader is enabled"""
129 """Whether this reloader is enabled"""
130
130
131 check_all = True
131 check_all = True
132 """Autoreload all modules, not just those listed in 'modules'"""
132 """Autoreload all modules, not just those listed in 'modules'"""
133
133
134 def __init__(self):
134 def __init__(self):
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
136 self.failed = {}
136 self.failed = {}
137 # Modules specially marked as autoreloadable.
137 # Modules specially marked as autoreloadable.
138 self.modules = {}
138 self.modules = {}
139 # Modules specially marked as not autoreloadable.
139 # Modules specially marked as not autoreloadable.
140 self.skip_modules = {}
140 self.skip_modules = {}
141 # (module-name, name) -> weakref, for replacing old code objects
141 # (module-name, name) -> weakref, for replacing old code objects
142 self.old_objects = {}
142 self.old_objects = {}
143 # Module modification timestamps
143 # Module modification timestamps
144 self.modules_mtimes = {}
144 self.modules_mtimes = {}
145
145
146 # Cache module modification times
146 # Cache module modification times
147 self.check(check_all=True, do_reload=False)
147 self.check(check_all=True, do_reload=False)
148
148
149 def mark_module_skipped(self, module_name):
149 def mark_module_skipped(self, module_name):
150 """Skip reloading the named module in the future"""
150 """Skip reloading the named module in the future"""
151 try:
151 try:
152 del self.modules[module_name]
152 del self.modules[module_name]
153 except KeyError:
153 except KeyError:
154 pass
154 pass
155 self.skip_modules[module_name] = True
155 self.skip_modules[module_name] = True
156
156
157 def mark_module_reloadable(self, module_name):
157 def mark_module_reloadable(self, module_name):
158 """Reload the named module in the future (if it is imported)"""
158 """Reload the named module in the future (if it is imported)"""
159 try:
159 try:
160 del self.skip_modules[module_name]
160 del self.skip_modules[module_name]
161 except KeyError:
161 except KeyError:
162 pass
162 pass
163 self.modules[module_name] = True
163 self.modules[module_name] = True
164
164
165 def aimport_module(self, module_name):
165 def aimport_module(self, module_name):
166 """Import a module, and mark it reloadable
166 """Import a module, and mark it reloadable
167
167
168 Returns
168 Returns
169 -------
169 -------
170 top_module : module
170 top_module : module
171 The imported module if it is top-level, or the top-level
171 The imported module if it is top-level, or the top-level
172 top_name : module
172 top_name : module
173 Name of top_module
173 Name of top_module
174
174
175 """
175 """
176 self.mark_module_reloadable(module_name)
176 self.mark_module_reloadable(module_name)
177
177
178 import_module(module_name)
178 import_module(module_name)
179 top_name = module_name.split('.')[0]
179 top_name = module_name.split('.')[0]
180 top_module = sys.modules[top_name]
180 top_module = sys.modules[top_name]
181 return top_module, top_name
181 return top_module, top_name
182
182
183 def filename_and_mtime(self, module):
183 def filename_and_mtime(self, module):
184 if not hasattr(module, '__file__') or module.__file__ is None:
184 if not hasattr(module, '__file__') or module.__file__ is None:
185 return None, None
185 return None, None
186
186
187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
188 # we cannot reload(__main__) or reload(__mp_main__)
188 # we cannot reload(__main__) or reload(__mp_main__)
189 return None, None
189 return None, None
190
190
191 filename = module.__file__
191 filename = module.__file__
192 path, ext = os.path.splitext(filename)
192 path, ext = os.path.splitext(filename)
193
193
194 if ext.lower() == '.py':
194 if ext.lower() == '.py':
195 py_filename = filename
195 py_filename = filename
196 else:
196 else:
197 try:
197 try:
198 py_filename = source_from_cache(filename)
198 py_filename = source_from_cache(filename)
199 except ValueError:
199 except ValueError:
200 return None, None
200 return None, None
201
201
202 try:
202 try:
203 pymtime = os.stat(py_filename).st_mtime
203 pymtime = os.stat(py_filename).st_mtime
204 except OSError:
204 except OSError:
205 return None, None
205 return None, None
206
206
207 return py_filename, pymtime
207 return py_filename, pymtime
208
208
209 def check(self, check_all=False, do_reload=True):
209 def check(self, check_all=False, do_reload=True):
210 """Check whether some modules need to be reloaded."""
210 """Check whether some modules need to be reloaded."""
211
211
212 if not self.enabled and not check_all:
212 if not self.enabled and not check_all:
213 return
213 return
214
214
215 if check_all or self.check_all:
215 if check_all or self.check_all:
216 modules = list(sys.modules.keys())
216 modules = list(sys.modules.keys())
217 else:
217 else:
218 modules = list(self.modules.keys())
218 modules = list(self.modules.keys())
219
219
220 for modname in modules:
220 for modname in modules:
221 m = sys.modules.get(modname, None)
221 m = sys.modules.get(modname, None)
222
222
223 if modname in self.skip_modules:
223 if modname in self.skip_modules:
224 continue
224 continue
225
225
226 py_filename, pymtime = self.filename_and_mtime(m)
226 py_filename, pymtime = self.filename_and_mtime(m)
227 if py_filename is None:
227 if py_filename is None:
228 continue
228 continue
229
229
230 try:
230 try:
231 if pymtime <= self.modules_mtimes[modname]:
231 if pymtime <= self.modules_mtimes[modname]:
232 continue
232 continue
233 except KeyError:
233 except KeyError:
234 self.modules_mtimes[modname] = pymtime
234 self.modules_mtimes[modname] = pymtime
235 continue
235 continue
236 else:
236 else:
237 if self.failed.get(py_filename, None) == pymtime:
237 if self.failed.get(py_filename, None) == pymtime:
238 continue
238 continue
239
239
240 self.modules_mtimes[modname] = pymtime
240 self.modules_mtimes[modname] = pymtime
241
241
242 # If we've reached this point, we should try to reload the module
242 # If we've reached this point, we should try to reload the module
243 if do_reload:
243 if do_reload:
244 try:
244 try:
245 superreload(m, reload, self.old_objects)
245 superreload(m, reload, self.old_objects)
246 if py_filename in self.failed:
246 if py_filename in self.failed:
247 del self.failed[py_filename]
247 del self.failed[py_filename]
248 except:
248 except:
249 print("[autoreload of %s failed: %s]" % (
249 print("[autoreload of %s failed: %s]" % (
250 modname, traceback.format_exc(10)), file=sys.stderr)
250 modname, traceback.format_exc(10)), file=sys.stderr)
251 self.failed[py_filename] = pymtime
251 self.failed[py_filename] = pymtime
252
252
253 #------------------------------------------------------------------------------
253 #------------------------------------------------------------------------------
254 # superreload
254 # superreload
255 #------------------------------------------------------------------------------
255 #------------------------------------------------------------------------------
256
256
257
257
258 func_attrs = ['__code__', '__defaults__', '__doc__',
258 func_attrs = ['__code__', '__defaults__', '__doc__',
259 '__closure__', '__globals__', '__dict__']
259 '__closure__', '__globals__', '__dict__']
260
260
261
261
262 def update_function(old, new):
262 def update_function(old, new):
263 """Upgrade the code object of a function"""
263 """Upgrade the code object of a function"""
264 for name in func_attrs:
264 for name in func_attrs:
265 try:
265 try:
266 setattr(old, name, getattr(new, name))
266 setattr(old, name, getattr(new, name))
267 except (AttributeError, TypeError):
267 except (AttributeError, TypeError):
268 pass
268 pass
269
269
270
270
271 def _find_instances(old_type):
272 """Try to find all instances of a class that need updating.
273
274 Classic graph exploration, we want to avoid re-visiting object multiple times.
275 """
276 # find ipython workspace stack frame, this is just to bootstrap where we
277 # find the object that need updating.
278 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
279 if 'trigger' in frame_nfo.function)
280 # build generator for non-private variable values from workspace
281 shell = frame.f_locals['self'].shell
282 user_ns = shell.user_ns
283 user_ns_hidden = shell.user_ns_hidden
284 nonmatching = object()
285 objects = ( value for key, value in user_ns.items()
286 if not key.startswith('_')
287 and (value is not user_ns_hidden.get(key, nonmatching))
288 and not inspect.ismodule(value))
289
290 # note: in the following we do use dict as object might not be hashable.
291 # list of objects we found that will need an update.
292 to_update = {}
293
294 # list of object we have not recursed into yet
295 open_set = {}
296
297 # list of object we have visited already
298 closed_set = {}
299
300 open_set.update({id(o):o for o in objects})
301
302 it = 0
303 while len(open_set) > 0:
304 it += 1
305 if it > 100_000:
306 raise ValueError('infinite')
307 (current_id,current) = next(iter(open_set.items()))
308 if type(current) is old_type:
309 to_update[current_id] = current
310 if hasattr(current, '__dict__') and not (inspect.isfunction(current)
311 or inspect.ismethod(current)):
312 potential_new = {id(o):o for o in current.__dict__.values() if id(o) not in closed_set.keys()}
313 open_set.update(potential_new)
314 # if object is a container, search it
315 if hasattr(current, 'items') or (hasattr(current, '__contains__')
316 and not isinstance(current, str)):
317 potential_new = (value for key, value in current.items()
318 if not str(key).startswith('_')
319 and not inspect.ismodule(value) and not id(value) in closed_set.keys())
320 open_set.update(potential_new)
321 del open_set[id(current)]
322 closed_set[id(current)] = current
323 return to_update.values()
324
271 def update_instances(old, new, objects=None):
325 def update_instances(old, new, objects=None):
272 """Iterate through objects recursively, searching for instances of old and
326 """Iterate through objects recursively, searching for instances of old and
273 replace their __class__ reference with new. If no objects are given, start
327 replace their __class__ reference with new. If no objects are given, start
274 with the current ipython workspace.
328 with the current ipython workspace.
275 """
329 """
276 if not objects:
330 if not objects:
277 # find ipython workspace stack frame
331 # find ipython workspace stack frame
278 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
332 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
279 if 'trigger' in frame_nfo.function)
333 if 'trigger' in frame_nfo.function)
280 # build generator for non-private variable values from workspace
334 # build generator for non-private variable values from workspace
281 shell = frame.f_locals['self'].shell
335 shell = frame.f_locals['self'].shell
282 user_ns = shell.user_ns
336 user_ns = shell.user_ns
283 user_ns_hidden = shell.user_ns_hidden
337 user_ns_hidden = shell.user_ns_hidden
284 nonmatching = object()
338 nonmatching = object()
285 objects = ( value for key, value in user_ns.items()
339 objects = ( value for key, value in user_ns.items()
286 if not key.startswith('_')
340 if not key.startswith('_')
287 and (value is not user_ns_hidden.get(key, nonmatching))
341 and (value is not user_ns_hidden.get(key, nonmatching))
288 and not inspect.ismodule(value))
342 and not inspect.ismodule(value))
289
343
290 # use dict values if objects is a dict but don't touch private variables
344 # use dict values if objects is a dict but don't touch private variables
291 if hasattr(objects, 'items'):
345 if hasattr(objects, 'items'):
292 objects = (value for key, value in objects.items()
346 objects = (value for key, value in objects.items()
293 if not str(key).startswith('_')
347 if not str(key).startswith('_')
294 and not inspect.ismodule(value) )
348 and not inspect.ismodule(value) )
295
349
296 # try if objects is iterable
350 # try if objects is iterable
297 try:
351 try:
298 for obj in objects:
352 for obj in objects:
299
353
300 # update, if object is instance of old_class (but no subclasses)
354 # update, if object is instance of old_class (but no subclasses)
301 if type(obj) is old:
355 if type(obj) is old:
302 obj.__class__ = new
356 obj.__class__ = new
303
357
304
358
305 # if object is instance of other class, look for nested instances
359 # if object is instance of other class, look for nested instances
306 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
360 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
307 or inspect.ismethod(obj)):
361 or inspect.ismethod(obj)):
308 update_instances(old, new, obj.__dict__)
362 update_instances(old, new, obj.__dict__)
309
363
310 # if object is a container, search it
364 # if object is a container, search it
311 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
365 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
312 and not isinstance(obj, str)):
366 and not isinstance(obj, str)):
313 update_instances(old, new, obj)
367 update_instances(old, new, obj)
314
368
315 except TypeError:
369 except TypeError:
316 pass
370 pass
317
371
318
372
319 def update_class(old, new):
373 def update_class(old, new):
320 """Replace stuff in the __dict__ of a class, and upgrade
374 """Replace stuff in the __dict__ of a class, and upgrade
321 method code objects, and add new methods, if any"""
375 method code objects, and add new methods, if any"""
376 print('old is', old)
322 for key in list(old.__dict__.keys()):
377 for key in list(old.__dict__.keys()):
323 old_obj = getattr(old, key)
378 old_obj = getattr(old, key)
324 try:
379 try:
325 new_obj = getattr(new, key)
380 new_obj = getattr(new, key)
326 # explicitly checking that comparison returns True to handle
381 # explicitly checking that comparison returns True to handle
327 # cases where `==` doesn't return a boolean.
382 # cases where `==` doesn't return a boolean.
328 if (old_obj == new_obj) is True:
383 if (old_obj == new_obj) is True:
329 continue
384 continue
330 except AttributeError:
385 except AttributeError:
331 # obsolete attribute: remove it
386 # obsolete attribute: remove it
332 try:
387 try:
333 delattr(old, key)
388 delattr(old, key)
334 except (AttributeError, TypeError):
389 except (AttributeError, TypeError):
335 pass
390 pass
336 continue
391 continue
337
392
338 if update_generic(old_obj, new_obj): continue
393 if update_generic(old_obj, new_obj): continue
339
394
340 try:
395 try:
341 setattr(old, key, getattr(new, key))
396 setattr(old, key, getattr(new, key))
342 except (AttributeError, TypeError):
397 except (AttributeError, TypeError):
343 pass # skip non-writable attributes
398 pass # skip non-writable attributes
344
399
345 for key in list(new.__dict__.keys()):
400 for key in list(new.__dict__.keys()):
346 if key not in list(old.__dict__.keys()):
401 if key not in list(old.__dict__.keys()):
347 try:
402 try:
348 setattr(old, key, getattr(new, key))
403 setattr(old, key, getattr(new, key))
349 except (AttributeError, TypeError):
404 except (AttributeError, TypeError):
350 pass # skip non-writable attributes
405 pass # skip non-writable attributes
351
406
352 # update all instances of class
407 # update all instances of class
353 update_instances(old, new)
408 for instance in _find_instances(old):
409 instance.__class__ = new
354
410
355
411
356 def update_property(old, new):
412 def update_property(old, new):
357 """Replace get/set/del functions of a property"""
413 """Replace get/set/del functions of a property"""
358 update_generic(old.fdel, new.fdel)
414 update_generic(old.fdel, new.fdel)
359 update_generic(old.fget, new.fget)
415 update_generic(old.fget, new.fget)
360 update_generic(old.fset, new.fset)
416 update_generic(old.fset, new.fset)
361
417
362
418
363 def isinstance2(a, b, typ):
419 def isinstance2(a, b, typ):
364 return isinstance(a, typ) and isinstance(b, typ)
420 return isinstance(a, typ) and isinstance(b, typ)
365
421
366
422
367 UPDATE_RULES = [
423 UPDATE_RULES = [
368 (lambda a, b: isinstance2(a, b, type),
424 (lambda a, b: isinstance2(a, b, type),
369 update_class),
425 update_class),
370 (lambda a, b: isinstance2(a, b, types.FunctionType),
426 (lambda a, b: isinstance2(a, b, types.FunctionType),
371 update_function),
427 update_function),
372 (lambda a, b: isinstance2(a, b, property),
428 (lambda a, b: isinstance2(a, b, property),
373 update_property),
429 update_property),
374 ]
430 ]
375 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
431 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
376 lambda a, b: update_function(a.__func__, b.__func__)),
432 lambda a, b: update_function(a.__func__, b.__func__)),
377 ])
433 ])
378
434
379
435
380 def update_generic(a, b):
436 def update_generic(a, b):
381 for type_check, update in UPDATE_RULES:
437 for type_check, update in UPDATE_RULES:
382 if type_check(a, b):
438 if type_check(a, b):
383 update(a, b)
439 update(a, b)
384 return True
440 return True
385 return False
441 return False
386
442
387
443
388 class StrongRef(object):
444 class StrongRef(object):
389 def __init__(self, obj):
445 def __init__(self, obj):
390 self.obj = obj
446 self.obj = obj
391 def __call__(self):
447 def __call__(self):
392 return self.obj
448 return self.obj
393
449
394
450
395 def superreload(module, reload=reload, old_objects=None):
451 def superreload(module, reload=reload, old_objects=None):
396 """Enhanced version of the builtin reload function.
452 """Enhanced version of the builtin reload function.
397
453
398 superreload remembers objects previously in the module, and
454 superreload remembers objects previously in the module, and
399
455
400 - upgrades the class dictionary of every old class in the module
456 - upgrades the class dictionary of every old class in the module
401 - upgrades the code object of every old function and method
457 - upgrades the code object of every old function and method
402 - clears the module's namespace before reloading
458 - clears the module's namespace before reloading
403
459
404 """
460 """
405 if old_objects is None:
461 if old_objects is None:
406 old_objects = {}
462 old_objects = {}
407
463
408 # collect old objects in the module
464 # collect old objects in the module
409 for name, obj in list(module.__dict__.items()):
465 for name, obj in list(module.__dict__.items()):
410 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
466 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
411 continue
467 continue
412 key = (module.__name__, name)
468 key = (module.__name__, name)
413 try:
469 try:
414 old_objects.setdefault(key, []).append(weakref.ref(obj))
470 old_objects.setdefault(key, []).append(weakref.ref(obj))
415 except TypeError:
471 except TypeError:
416 pass
472 pass
417
473
418 # reload module
474 # reload module
419 try:
475 try:
420 # clear namespace first from old cruft
476 # clear namespace first from old cruft
421 old_dict = module.__dict__.copy()
477 old_dict = module.__dict__.copy()
422 old_name = module.__name__
478 old_name = module.__name__
423 module.__dict__.clear()
479 module.__dict__.clear()
424 module.__dict__['__name__'] = old_name
480 module.__dict__['__name__'] = old_name
425 module.__dict__['__loader__'] = old_dict['__loader__']
481 module.__dict__['__loader__'] = old_dict['__loader__']
426 except (TypeError, AttributeError, KeyError):
482 except (TypeError, AttributeError, KeyError):
427 pass
483 pass
428
484
429 try:
485 try:
430 module = reload(module)
486 module = reload(module)
431 except:
487 except:
432 # restore module dictionary on failed reload
488 # restore module dictionary on failed reload
433 module.__dict__.update(old_dict)
489 module.__dict__.update(old_dict)
434 raise
490 raise
435
491
436 # iterate over all objects and update functions & classes
492 # iterate over all objects and update functions & classes
437 for name, new_obj in list(module.__dict__.items()):
493 for name, new_obj in list(module.__dict__.items()):
438 key = (module.__name__, name)
494 key = (module.__name__, name)
439 if key not in old_objects: continue
495 if key not in old_objects: continue
440
496
441 new_refs = []
497 new_refs = []
442 for old_ref in old_objects[key]:
498 for old_ref in old_objects[key]:
443 old_obj = old_ref()
499 old_obj = old_ref()
444 if old_obj is None: continue
500 if old_obj is None: continue
445 new_refs.append(old_ref)
501 new_refs.append(old_ref)
446 update_generic(old_obj, new_obj)
502 update_generic(old_obj, new_obj)
447
503
448 if new_refs:
504 if new_refs:
449 old_objects[key] = new_refs
505 old_objects[key] = new_refs
450 else:
506 else:
451 del old_objects[key]
507 del old_objects[key]
452
508
453 return module
509 return module
454
510
455 #------------------------------------------------------------------------------
511 #------------------------------------------------------------------------------
456 # IPython connectivity
512 # IPython connectivity
457 #------------------------------------------------------------------------------
513 #------------------------------------------------------------------------------
458
514
459 from IPython.core.magic import Magics, magics_class, line_magic
515 from IPython.core.magic import Magics, magics_class, line_magic
460
516
461 @magics_class
517 @magics_class
462 class AutoreloadMagics(Magics):
518 class AutoreloadMagics(Magics):
463 def __init__(self, *a, **kw):
519 def __init__(self, *a, **kw):
464 super(AutoreloadMagics, self).__init__(*a, **kw)
520 super(AutoreloadMagics, self).__init__(*a, **kw)
465 self._reloader = ModuleReloader()
521 self._reloader = ModuleReloader()
466 self._reloader.check_all = False
522 self._reloader.check_all = False
467 self.loaded_modules = set(sys.modules)
523 self.loaded_modules = set(sys.modules)
468
524
469 @line_magic
525 @line_magic
470 def autoreload(self, parameter_s=''):
526 def autoreload(self, parameter_s=''):
471 r"""%autoreload => Reload modules automatically
527 r"""%autoreload => Reload modules automatically
472
528
473 %autoreload
529 %autoreload
474 Reload all modules (except those excluded by %aimport) automatically
530 Reload all modules (except those excluded by %aimport) automatically
475 now.
531 now.
476
532
477 %autoreload 0
533 %autoreload 0
478 Disable automatic reloading.
534 Disable automatic reloading.
479
535
480 %autoreload 1
536 %autoreload 1
481 Reload all modules imported with %aimport every time before executing
537 Reload all modules imported with %aimport every time before executing
482 the Python code typed.
538 the Python code typed.
483
539
484 %autoreload 2
540 %autoreload 2
485 Reload all modules (except those excluded by %aimport) every time
541 Reload all modules (except those excluded by %aimport) every time
486 before executing the Python code typed.
542 before executing the Python code typed.
487
543
488 Reloading Python modules in a reliable way is in general
544 Reloading Python modules in a reliable way is in general
489 difficult, and unexpected things may occur. %autoreload tries to
545 difficult, and unexpected things may occur. %autoreload tries to
490 work around common pitfalls by replacing function code objects and
546 work around common pitfalls by replacing function code objects and
491 parts of classes previously in the module with new versions. This
547 parts of classes previously in the module with new versions. This
492 makes the following things to work:
548 makes the following things to work:
493
549
494 - Functions and classes imported via 'from xxx import foo' are upgraded
550 - Functions and classes imported via 'from xxx import foo' are upgraded
495 to new versions when 'xxx' is reloaded.
551 to new versions when 'xxx' is reloaded.
496
552
497 - Methods and properties of classes are upgraded on reload, so that
553 - Methods and properties of classes are upgraded on reload, so that
498 calling 'c.foo()' on an object 'c' created before the reload causes
554 calling 'c.foo()' on an object 'c' created before the reload causes
499 the new code for 'foo' to be executed.
555 the new code for 'foo' to be executed.
500
556
501 Some of the known remaining caveats are:
557 Some of the known remaining caveats are:
502
558
503 - Replacing code objects does not always succeed: changing a @property
559 - Replacing code objects does not always succeed: changing a @property
504 in a class to an ordinary method or a method to a member variable
560 in a class to an ordinary method or a method to a member variable
505 can cause problems (but in old objects only).
561 can cause problems (but in old objects only).
506
562
507 - Functions that are removed (eg. via monkey-patching) from a module
563 - Functions that are removed (eg. via monkey-patching) from a module
508 before it is reloaded are not upgraded.
564 before it is reloaded are not upgraded.
509
565
510 - C extension modules cannot be reloaded, and so cannot be
566 - C extension modules cannot be reloaded, and so cannot be
511 autoreloaded.
567 autoreloaded.
512
568
513 """
569 """
514 if parameter_s == '':
570 if parameter_s == '':
515 self._reloader.check(True)
571 self._reloader.check(True)
516 elif parameter_s == '0':
572 elif parameter_s == '0':
517 self._reloader.enabled = False
573 self._reloader.enabled = False
518 elif parameter_s == '1':
574 elif parameter_s == '1':
519 self._reloader.check_all = False
575 self._reloader.check_all = False
520 self._reloader.enabled = True
576 self._reloader.enabled = True
521 elif parameter_s == '2':
577 elif parameter_s == '2':
522 self._reloader.check_all = True
578 self._reloader.check_all = True
523 self._reloader.enabled = True
579 self._reloader.enabled = True
524
580
525 @line_magic
581 @line_magic
526 def aimport(self, parameter_s='', stream=None):
582 def aimport(self, parameter_s='', stream=None):
527 """%aimport => Import modules for automatic reloading.
583 """%aimport => Import modules for automatic reloading.
528
584
529 %aimport
585 %aimport
530 List modules to automatically import and not to import.
586 List modules to automatically import and not to import.
531
587
532 %aimport foo
588 %aimport foo
533 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
589 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
534
590
535 %aimport foo, bar
591 %aimport foo, bar
536 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
592 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
537
593
538 %aimport -foo
594 %aimport -foo
539 Mark module 'foo' to not be autoreloaded for %autoreload 1
595 Mark module 'foo' to not be autoreloaded for %autoreload 1
540 """
596 """
541 modname = parameter_s
597 modname = parameter_s
542 if not modname:
598 if not modname:
543 to_reload = sorted(self._reloader.modules.keys())
599 to_reload = sorted(self._reloader.modules.keys())
544 to_skip = sorted(self._reloader.skip_modules.keys())
600 to_skip = sorted(self._reloader.skip_modules.keys())
545 if stream is None:
601 if stream is None:
546 stream = sys.stdout
602 stream = sys.stdout
547 if self._reloader.check_all:
603 if self._reloader.check_all:
548 stream.write("Modules to reload:\nall-except-skipped\n")
604 stream.write("Modules to reload:\nall-except-skipped\n")
549 else:
605 else:
550 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
606 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
551 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
607 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
552 elif modname.startswith('-'):
608 elif modname.startswith('-'):
553 modname = modname[1:]
609 modname = modname[1:]
554 self._reloader.mark_module_skipped(modname)
610 self._reloader.mark_module_skipped(modname)
555 else:
611 else:
556 for _module in ([_.strip() for _ in modname.split(',')]):
612 for _module in ([_.strip() for _ in modname.split(',')]):
557 top_module, top_name = self._reloader.aimport_module(_module)
613 top_module, top_name = self._reloader.aimport_module(_module)
558
614
559 # Inject module to user namespace
615 # Inject module to user namespace
560 self.shell.push({top_name: top_module})
616 self.shell.push({top_name: top_module})
561
617
562 def pre_run_cell(self):
618 def pre_run_cell(self):
563 if self._reloader.enabled:
619 if self._reloader.enabled:
564 try:
620 try:
565 self._reloader.check()
621 self._reloader.check()
566 except:
622 except:
567 pass
623 pass
568
624
569 def post_execute_hook(self):
625 def post_execute_hook(self):
570 """Cache the modification times of any modules imported in this execution
626 """Cache the modification times of any modules imported in this execution
571 """
627 """
572 newly_loaded_modules = set(sys.modules) - self.loaded_modules
628 newly_loaded_modules = set(sys.modules) - self.loaded_modules
573 for modname in newly_loaded_modules:
629 for modname in newly_loaded_modules:
574 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
630 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
575 if pymtime is not None:
631 if pymtime is not None:
576 self._reloader.modules_mtimes[modname] = pymtime
632 self._reloader.modules_mtimes[modname] = pymtime
577
633
578 self.loaded_modules.update(newly_loaded_modules)
634 self.loaded_modules.update(newly_loaded_modules)
579
635
580
636
581 def load_ipython_extension(ip):
637 def load_ipython_extension(ip):
582 """Load the extension in IPython."""
638 """Load the extension in IPython."""
583 auto_reload = AutoreloadMagics(ip)
639 auto_reload = AutoreloadMagics(ip)
584 ip.register_magics(auto_reload)
640 ip.register_magics(auto_reload)
585 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
641 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
586 ip.events.register('post_execute', auto_reload.post_execute_hook)
642 ip.events.register('post_execute', auto_reload.post_execute_hook)
@@ -1,398 +1,400 b''
1 """Tests for autoreload extension.
1 """Tests for autoreload extension.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 IPython Development Team.
4 # Copyright (c) 2012 IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import textwrap
18 import textwrap
19 import shutil
19 import shutil
20 import random
20 import random
21 import time
21 import time
22 from io import StringIO
22 from io import StringIO
23
23
24 import nose.tools as nt
24 import nose.tools as nt
25 import IPython.testing.tools as tt
25 import IPython.testing.tools as tt
26
26
27 from IPython.testing.decorators import skipif
27 from IPython.testing.decorators import skipif
28
28
29 from IPython.extensions.autoreload import AutoreloadMagics
29 from IPython.extensions.autoreload import AutoreloadMagics
30 from IPython.core.events import EventManager, pre_run_cell
30 from IPython.core.events import EventManager, pre_run_cell
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Test fixture
33 # Test fixture
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 noop = lambda *a, **kw: None
36 noop = lambda *a, **kw: None
37
37
38 class FakeShell(object):
38 class FakeShell:
39
39
40 def __init__(self):
40 def __init__(self):
41 self.ns = {}
41 self.ns = {}
42 self.user_ns = {}
43 self.user_ns_hidden = {}
42 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
44 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
43 self.auto_magics = AutoreloadMagics(shell=self)
45 self.auto_magics = AutoreloadMagics(shell=self)
44 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
46 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
45
47
46 register_magics = set_hook = noop
48 register_magics = set_hook = noop
47
49
48 def run_code(self, code):
50 def run_code(self, code):
49 self.events.trigger('pre_run_cell')
51 self.events.trigger('pre_run_cell')
50 exec(code, self.ns)
52 exec(code, self.ns)
51 self.auto_magics.post_execute_hook()
53 self.auto_magics.post_execute_hook()
52
54
53 def push(self, items):
55 def push(self, items):
54 self.ns.update(items)
56 self.ns.update(items)
55
57
56 def magic_autoreload(self, parameter):
58 def magic_autoreload(self, parameter):
57 self.auto_magics.autoreload(parameter)
59 self.auto_magics.autoreload(parameter)
58
60
59 def magic_aimport(self, parameter, stream=None):
61 def magic_aimport(self, parameter, stream=None):
60 self.auto_magics.aimport(parameter, stream=stream)
62 self.auto_magics.aimport(parameter, stream=stream)
61 self.auto_magics.post_execute_hook()
63 self.auto_magics.post_execute_hook()
62
64
63
65
64 class Fixture(object):
66 class Fixture(object):
65 """Fixture for creating test module files"""
67 """Fixture for creating test module files"""
66
68
67 test_dir = None
69 test_dir = None
68 old_sys_path = None
70 old_sys_path = None
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
71 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
70
72
71 def setUp(self):
73 def setUp(self):
72 self.test_dir = tempfile.mkdtemp()
74 self.test_dir = tempfile.mkdtemp()
73 self.old_sys_path = list(sys.path)
75 self.old_sys_path = list(sys.path)
74 sys.path.insert(0, self.test_dir)
76 sys.path.insert(0, self.test_dir)
75 self.shell = FakeShell()
77 self.shell = FakeShell()
76
78
77 def tearDown(self):
79 def tearDown(self):
78 shutil.rmtree(self.test_dir)
80 shutil.rmtree(self.test_dir)
79 sys.path = self.old_sys_path
81 sys.path = self.old_sys_path
80
82
81 self.test_dir = None
83 self.test_dir = None
82 self.old_sys_path = None
84 self.old_sys_path = None
83 self.shell = None
85 self.shell = None
84
86
85 def get_module(self):
87 def get_module(self):
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
88 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
87 if module_name in sys.modules:
89 if module_name in sys.modules:
88 del sys.modules[module_name]
90 del sys.modules[module_name]
89 file_name = os.path.join(self.test_dir, module_name + ".py")
91 file_name = os.path.join(self.test_dir, module_name + ".py")
90 return module_name, file_name
92 return module_name, file_name
91
93
92 def write_file(self, filename, content):
94 def write_file(self, filename, content):
93 """
95 """
94 Write a file, and force a timestamp difference of at least one second
96 Write a file, and force a timestamp difference of at least one second
95
97
96 Notes
98 Notes
97 -----
99 -----
98 Python's .pyc files record the timestamp of their compilation
100 Python's .pyc files record the timestamp of their compilation
99 with a time resolution of one second.
101 with a time resolution of one second.
100
102
101 Therefore, we need to force a timestamp difference between .py
103 Therefore, we need to force a timestamp difference between .py
102 and .pyc, without having the .py file be timestamped in the
104 and .pyc, without having the .py file be timestamped in the
103 future, and without changing the timestamp of the .pyc file
105 future, and without changing the timestamp of the .pyc file
104 (because that is stored in the file). The only reliable way
106 (because that is stored in the file). The only reliable way
105 to achieve this seems to be to sleep.
107 to achieve this seems to be to sleep.
106 """
108 """
107
109
108 # Sleep one second + eps
110 # Sleep one second + eps
109 time.sleep(1.05)
111 time.sleep(1.05)
110
112
111 # Write
113 # Write
112 with open(filename, 'w') as f:
114 with open(filename, 'w') as f:
113 f.write(content)
115 f.write(content)
114
116
115 def new_module(self, code):
117 def new_module(self, code):
116 mod_name, mod_fn = self.get_module()
118 mod_name, mod_fn = self.get_module()
117 with open(mod_fn, 'w') as f:
119 with open(mod_fn, 'w') as f:
118 f.write(code)
120 f.write(code)
119 return mod_name, mod_fn
121 return mod_name, mod_fn
120
122
121 #-----------------------------------------------------------------------------
123 #-----------------------------------------------------------------------------
122 # Test automatic reloading
124 # Test automatic reloading
123 #-----------------------------------------------------------------------------
125 #-----------------------------------------------------------------------------
124
126
125 class TestAutoreload(Fixture):
127 class TestAutoreload(Fixture):
126
128
127 @skipif(sys.version_info < (3, 6))
129 @skipif(sys.version_info < (3, 6))
128 def test_reload_enums(self):
130 def test_reload_enums(self):
129 import enum
131 import enum
130 mod_name, mod_fn = self.new_module(textwrap.dedent("""
132 mod_name, mod_fn = self.new_module(textwrap.dedent("""
131 from enum import Enum
133 from enum import Enum
132 class MyEnum(Enum):
134 class MyEnum(Enum):
133 A = 'A'
135 A = 'A'
134 B = 'B'
136 B = 'B'
135 """))
137 """))
136 self.shell.magic_autoreload("2")
138 self.shell.magic_autoreload("2")
137 self.shell.magic_aimport(mod_name)
139 self.shell.magic_aimport(mod_name)
138 self.write_file(mod_fn, textwrap.dedent("""
140 self.write_file(mod_fn, textwrap.dedent("""
139 from enum import Enum
141 from enum import Enum
140 class MyEnum(Enum):
142 class MyEnum(Enum):
141 A = 'A'
143 A = 'A'
142 B = 'B'
144 B = 'B'
143 C = 'C'
145 C = 'C'
144 """))
146 """))
145 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
147 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
146 self.shell.run_code("pass") # trigger another reload
148 self.shell.run_code("pass") # trigger another reload
147
149
148 def test_reload_class_attributes(self):
150 def test_reload_class_attributes(self):
149 self.shell.magic_autoreload("2")
151 self.shell.magic_autoreload("2")
150 mod_name, mod_fn = self.new_module(textwrap.dedent("""
152 mod_name, mod_fn = self.new_module(textwrap.dedent("""
151 class MyClass:
153 class MyClass:
152
154
153 def __init__(self, a=10):
155 def __init__(self, a=10):
154 self.a = a
156 self.a = a
155 self.b = 22
157 self.b = 22
156 # self.toto = 33
158 # self.toto = 33
157
159
158 def square(self):
160 def square(self):
159 print('compute square')
161 print('compute square')
160 return self.a*self.a
162 return self.a*self.a
161 """
163 """
162 )
164 )
163 )
165 )
164 self.shell.run_code("from %s import MyClass" % mod_name)
166 self.shell.run_code("from %s import MyClass" % mod_name)
165 self.shell.run_code("first = MyClass(5)")
167 self.shell.run_code("first = MyClass(5)")
166 self.shell.run_code("first.square()")
168 self.shell.run_code("first.square()")
167 with nt.assert_raises(AttributeError):
169 with nt.assert_raises(AttributeError):
168 self.shell.run_code("first.cube()")
170 self.shell.run_code("first.cube()")
169 with nt.assert_raises(AttributeError):
171 with nt.assert_raises(AttributeError):
170 self.shell.run_code("first.power(5)")
172 self.shell.run_code("first.power(5)")
171 self.shell.run_code("first.b")
173 self.shell.run_code("first.b")
172 with nt.assert_raises(AttributeError):
174 with nt.assert_raises(AttributeError):
173 self.shell.run_code("first.toto")
175 self.shell.run_code("first.toto")
174
176
175 # remove square, add power
177 # remove square, add power
176
178
177 self.write_file(
179 self.write_file(
178 mod_fn,
180 mod_fn,
179 textwrap.dedent(
181 textwrap.dedent(
180 """
182 """
181 class MyClass:
183 class MyClass:
182
184
183 def __init__(self, a=10):
185 def __init__(self, a=10):
184 self.a = a
186 self.a = a
185 self.b = 11
187 self.b = 11
186
188
187 def power(self, p):
189 def power(self, p):
188 print('compute power '+str(p))
190 print('compute power '+str(p))
189 return self.a**p
191 return self.a**p
190 """
192 """
191 ),
193 ),
192 )
194 )
193
195
194 self.shell.run_code("second = MyClass(5)")
196 self.shell.run_code("second = MyClass(5)")
195
197
196 for object_name in {'first', 'second'}:
198 for object_name in {'first', 'second'}:
197 self.shell.run_code("{object_name}.power(5)".format(object_name=object_name))
199 self.shell.run_code("{object_name}.power(5)".format(object_name=object_name))
198 with nt.assert_raises(AttributeError):
200 with nt.assert_raises(AttributeError):
199 self.shell.run_code("{object_name}.cube()".format(object_name=object_name))
201 self.shell.run_code("{object_name}.cube()".format(object_name=object_name))
200 with nt.assert_raises(AttributeError):
202 with nt.assert_raises(AttributeError):
201 self.shell.run_code("{object_name}.square()".format(object_name=object_name))
203 self.shell.run_code("{object_name}.square()".format(object_name=object_name))
202 self.shell.run_code("{object_name}.b".format(object_name=object_name))
204 self.shell.run_code("{object_name}.b".format(object_name=object_name))
203 self.shell.run_code("{object_name}.a".format(object_name=object_name))
205 self.shell.run_code("{object_name}.a".format(object_name=object_name))
204 with nt.assert_raises(AttributeError):
206 with nt.assert_raises(AttributeError):
205 self.shell.run_code("{object_name}.toto".format(object_name=object_name))
207 self.shell.run_code("{object_name}.toto".format(object_name=object_name))
206
208
207 def _check_smoketest(self, use_aimport=True):
209 def _check_smoketest(self, use_aimport=True):
208 """
210 """
209 Functional test for the automatic reloader using either
211 Functional test for the automatic reloader using either
210 '%autoreload 1' or '%autoreload 2'
212 '%autoreload 1' or '%autoreload 2'
211 """
213 """
212
214
213 mod_name, mod_fn = self.new_module("""
215 mod_name, mod_fn = self.new_module("""
214 x = 9
216 x = 9
215
217
216 z = 123 # this item will be deleted
218 z = 123 # this item will be deleted
217
219
218 def foo(y):
220 def foo(y):
219 return y + 3
221 return y + 3
220
222
221 class Baz(object):
223 class Baz(object):
222 def __init__(self, x):
224 def __init__(self, x):
223 self.x = x
225 self.x = x
224 def bar(self, y):
226 def bar(self, y):
225 return self.x + y
227 return self.x + y
226 @property
228 @property
227 def quux(self):
229 def quux(self):
228 return 42
230 return 42
229 def zzz(self):
231 def zzz(self):
230 '''This method will be deleted below'''
232 '''This method will be deleted below'''
231 return 99
233 return 99
232
234
233 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
235 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
234 def foo(self):
236 def foo(self):
235 return 1
237 return 1
236 """)
238 """)
237
239
238 #
240 #
239 # Import module, and mark for reloading
241 # Import module, and mark for reloading
240 #
242 #
241 if use_aimport:
243 if use_aimport:
242 self.shell.magic_autoreload("1")
244 self.shell.magic_autoreload("1")
243 self.shell.magic_aimport(mod_name)
245 self.shell.magic_aimport(mod_name)
244 stream = StringIO()
246 stream = StringIO()
245 self.shell.magic_aimport("", stream=stream)
247 self.shell.magic_aimport("", stream=stream)
246 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
248 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
247
249
248 with nt.assert_raises(ImportError):
250 with nt.assert_raises(ImportError):
249 self.shell.magic_aimport("tmpmod_as318989e89ds")
251 self.shell.magic_aimport("tmpmod_as318989e89ds")
250 else:
252 else:
251 self.shell.magic_autoreload("2")
253 self.shell.magic_autoreload("2")
252 self.shell.run_code("import %s" % mod_name)
254 self.shell.run_code("import %s" % mod_name)
253 stream = StringIO()
255 stream = StringIO()
254 self.shell.magic_aimport("", stream=stream)
256 self.shell.magic_aimport("", stream=stream)
255 nt.assert_true("Modules to reload:\nall-except-skipped" in
257 nt.assert_true("Modules to reload:\nall-except-skipped" in
256 stream.getvalue())
258 stream.getvalue())
257 nt.assert_in(mod_name, self.shell.ns)
259 nt.assert_in(mod_name, self.shell.ns)
258
260
259 mod = sys.modules[mod_name]
261 mod = sys.modules[mod_name]
260
262
261 #
263 #
262 # Test module contents
264 # Test module contents
263 #
265 #
264 old_foo = mod.foo
266 old_foo = mod.foo
265 old_obj = mod.Baz(9)
267 old_obj = mod.Baz(9)
266 old_obj2 = mod.Bar()
268 old_obj2 = mod.Bar()
267
269
268 def check_module_contents():
270 def check_module_contents():
269 nt.assert_equal(mod.x, 9)
271 nt.assert_equal(mod.x, 9)
270 nt.assert_equal(mod.z, 123)
272 nt.assert_equal(mod.z, 123)
271
273
272 nt.assert_equal(old_foo(0), 3)
274 nt.assert_equal(old_foo(0), 3)
273 nt.assert_equal(mod.foo(0), 3)
275 nt.assert_equal(mod.foo(0), 3)
274
276
275 obj = mod.Baz(9)
277 obj = mod.Baz(9)
276 nt.assert_equal(old_obj.bar(1), 10)
278 nt.assert_equal(old_obj.bar(1), 10)
277 nt.assert_equal(obj.bar(1), 10)
279 nt.assert_equal(obj.bar(1), 10)
278 nt.assert_equal(obj.quux, 42)
280 nt.assert_equal(obj.quux, 42)
279 nt.assert_equal(obj.zzz(), 99)
281 nt.assert_equal(obj.zzz(), 99)
280
282
281 obj2 = mod.Bar()
283 obj2 = mod.Bar()
282 nt.assert_equal(old_obj2.foo(), 1)
284 nt.assert_equal(old_obj2.foo(), 1)
283 nt.assert_equal(obj2.foo(), 1)
285 nt.assert_equal(obj2.foo(), 1)
284
286
285 check_module_contents()
287 check_module_contents()
286
288
287 #
289 #
288 # Simulate a failed reload: no reload should occur and exactly
290 # Simulate a failed reload: no reload should occur and exactly
289 # one error message should be printed
291 # one error message should be printed
290 #
292 #
291 self.write_file(mod_fn, """
293 self.write_file(mod_fn, """
292 a syntax error
294 a syntax error
293 """)
295 """)
294
296
295 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
297 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
296 self.shell.run_code("pass") # trigger reload
298 self.shell.run_code("pass") # trigger reload
297 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
299 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
298 self.shell.run_code("pass") # trigger another reload
300 self.shell.run_code("pass") # trigger another reload
299 check_module_contents()
301 check_module_contents()
300
302
301 #
303 #
302 # Rewrite module (this time reload should succeed)
304 # Rewrite module (this time reload should succeed)
303 #
305 #
304 self.write_file(mod_fn, """
306 self.write_file(mod_fn, """
305 x = 10
307 x = 10
306
308
307 def foo(y):
309 def foo(y):
308 return y + 4
310 return y + 4
309
311
310 class Baz(object):
312 class Baz(object):
311 def __init__(self, x):
313 def __init__(self, x):
312 self.x = x
314 self.x = x
313 def bar(self, y):
315 def bar(self, y):
314 return self.x + y + 1
316 return self.x + y + 1
315 @property
317 @property
316 def quux(self):
318 def quux(self):
317 return 43
319 return 43
318
320
319 class Bar: # old-style class
321 class Bar: # old-style class
320 def foo(self):
322 def foo(self):
321 return 2
323 return 2
322 """)
324 """)
323
325
324 def check_module_contents():
326 def check_module_contents():
325 nt.assert_equal(mod.x, 10)
327 nt.assert_equal(mod.x, 10)
326 nt.assert_false(hasattr(mod, 'z'))
328 nt.assert_false(hasattr(mod, 'z'))
327
329
328 nt.assert_equal(old_foo(0), 4) # superreload magic!
330 nt.assert_equal(old_foo(0), 4) # superreload magic!
329 nt.assert_equal(mod.foo(0), 4)
331 nt.assert_equal(mod.foo(0), 4)
330
332
331 obj = mod.Baz(9)
333 obj = mod.Baz(9)
332 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
334 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
333 nt.assert_equal(obj.bar(1), 11)
335 nt.assert_equal(obj.bar(1), 11)
334
336
335 nt.assert_equal(old_obj.quux, 43)
337 nt.assert_equal(old_obj.quux, 43)
336 nt.assert_equal(obj.quux, 43)
338 nt.assert_equal(obj.quux, 43)
337
339
338 nt.assert_false(hasattr(old_obj, 'zzz'))
340 nt.assert_false(hasattr(old_obj, 'zzz'))
339 nt.assert_false(hasattr(obj, 'zzz'))
341 nt.assert_false(hasattr(obj, 'zzz'))
340
342
341 obj2 = mod.Bar()
343 obj2 = mod.Bar()
342 nt.assert_equal(old_obj2.foo(), 2)
344 nt.assert_equal(old_obj2.foo(), 2)
343 nt.assert_equal(obj2.foo(), 2)
345 nt.assert_equal(obj2.foo(), 2)
344
346
345 self.shell.run_code("pass") # trigger reload
347 self.shell.run_code("pass") # trigger reload
346 check_module_contents()
348 check_module_contents()
347
349
348 #
350 #
349 # Another failure case: deleted file (shouldn't reload)
351 # Another failure case: deleted file (shouldn't reload)
350 #
352 #
351 os.unlink(mod_fn)
353 os.unlink(mod_fn)
352
354
353 self.shell.run_code("pass") # trigger reload
355 self.shell.run_code("pass") # trigger reload
354 check_module_contents()
356 check_module_contents()
355
357
356 #
358 #
357 # Disable autoreload and rewrite module: no reload should occur
359 # Disable autoreload and rewrite module: no reload should occur
358 #
360 #
359 if use_aimport:
361 if use_aimport:
360 self.shell.magic_aimport("-" + mod_name)
362 self.shell.magic_aimport("-" + mod_name)
361 stream = StringIO()
363 stream = StringIO()
362 self.shell.magic_aimport("", stream=stream)
364 self.shell.magic_aimport("", stream=stream)
363 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
365 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
364 stream.getvalue())
366 stream.getvalue())
365
367
366 # This should succeed, although no such module exists
368 # This should succeed, although no such module exists
367 self.shell.magic_aimport("-tmpmod_as318989e89ds")
369 self.shell.magic_aimport("-tmpmod_as318989e89ds")
368 else:
370 else:
369 self.shell.magic_autoreload("0")
371 self.shell.magic_autoreload("0")
370
372
371 self.write_file(mod_fn, """
373 self.write_file(mod_fn, """
372 x = -99
374 x = -99
373 """)
375 """)
374
376
375 self.shell.run_code("pass") # trigger reload
377 self.shell.run_code("pass") # trigger reload
376 self.shell.run_code("pass")
378 self.shell.run_code("pass")
377 check_module_contents()
379 check_module_contents()
378
380
379 #
381 #
380 # Re-enable autoreload: reload should now occur
382 # Re-enable autoreload: reload should now occur
381 #
383 #
382 if use_aimport:
384 if use_aimport:
383 self.shell.magic_aimport(mod_name)
385 self.shell.magic_aimport(mod_name)
384 else:
386 else:
385 self.shell.magic_autoreload("")
387 self.shell.magic_autoreload("")
386
388
387 self.shell.run_code("pass") # trigger reload
389 self.shell.run_code("pass") # trigger reload
388 nt.assert_equal(mod.x, -99)
390 nt.assert_equal(mod.x, -99)
389
391
390 def test_smoketest_aimport(self):
392 def test_smoketest_aimport(self):
391 self._check_smoketest(use_aimport=True)
393 self._check_smoketest(use_aimport=True)
392
394
393 def test_smoketest_autoreload(self):
395 def test_smoketest_autoreload(self):
394 self._check_smoketest(use_aimport=False)
396 self._check_smoketest(use_aimport=False)
395
397
396
398
397
399
398
400
General Comments 0
You need to be logged in to leave comments. Login now