##// END OF EJS Templates
Merge pull request #11283 from oscar6echo/improve-autoreload...
Matthias Bussonnier -
r24522:197a017e merge
parent child Browse files
Show More
@@ -0,0 +1,30 b''
1 magic `%autoreload 2` now captures new methods added to classes. Earlier, only methods existing as of the initial import were being tracked and updated.
2
3 This new feature helps dual environement development - Jupyter+IDE - where the code gradually moves from notebook cells to package files, as it gets structured.
4
5 **Example**: An instance of the class `MyClass` will be able to access the method `cube()` after it is uncommented and the file `file1.py` saved on disk.
6
7 ````python
8 # notebook
9
10 from mymodule import MyClass
11 first = MyClass(5)
12 ````
13
14 ````python
15 # mymodule/file1.py
16
17 class MyClass:
18
19 def __init__(self, a=10):
20 self.a = a
21
22 def square(self):
23 print('compute square')
24 return self.a*self.a
25
26 # def cube(self):
27 # print('compute cube')
28 # return self.a*self.a*self.a
29 ````
30
@@ -1,523 +1,530 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 from importlib import import_module
118 from importlib import import_module
119 from importlib.util import source_from_cache
119 from importlib.util import source_from_cache
120 from imp import reload
120 from imp import reload
121
121
122 #------------------------------------------------------------------------------
122 #------------------------------------------------------------------------------
123 # Autoreload functionality
123 # Autoreload functionality
124 #------------------------------------------------------------------------------
124 #------------------------------------------------------------------------------
125
125
126 class ModuleReloader(object):
126 class ModuleReloader(object):
127 enabled = False
127 enabled = False
128 """Whether this reloader is enabled"""
128 """Whether this reloader is enabled"""
129
129
130 check_all = True
130 check_all = True
131 """Autoreload all modules, not just those listed in 'modules'"""
131 """Autoreload all modules, not just those listed in 'modules'"""
132
132
133 def __init__(self):
133 def __init__(self):
134 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
134 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 self.failed = {}
135 self.failed = {}
136 # Modules specially marked as autoreloadable.
136 # Modules specially marked as autoreloadable.
137 self.modules = {}
137 self.modules = {}
138 # Modules specially marked as not autoreloadable.
138 # Modules specially marked as not autoreloadable.
139 self.skip_modules = {}
139 self.skip_modules = {}
140 # (module-name, name) -> weakref, for replacing old code objects
140 # (module-name, name) -> weakref, for replacing old code objects
141 self.old_objects = {}
141 self.old_objects = {}
142 # Module modification timestamps
142 # Module modification timestamps
143 self.modules_mtimes = {}
143 self.modules_mtimes = {}
144
144
145 # Cache module modification times
145 # Cache module modification times
146 self.check(check_all=True, do_reload=False)
146 self.check(check_all=True, do_reload=False)
147
147
148 def mark_module_skipped(self, module_name):
148 def mark_module_skipped(self, module_name):
149 """Skip reloading the named module in the future"""
149 """Skip reloading the named module in the future"""
150 try:
150 try:
151 del self.modules[module_name]
151 del self.modules[module_name]
152 except KeyError:
152 except KeyError:
153 pass
153 pass
154 self.skip_modules[module_name] = True
154 self.skip_modules[module_name] = True
155
155
156 def mark_module_reloadable(self, module_name):
156 def mark_module_reloadable(self, module_name):
157 """Reload the named module in the future (if it is imported)"""
157 """Reload the named module in the future (if it is imported)"""
158 try:
158 try:
159 del self.skip_modules[module_name]
159 del self.skip_modules[module_name]
160 except KeyError:
160 except KeyError:
161 pass
161 pass
162 self.modules[module_name] = True
162 self.modules[module_name] = True
163
163
164 def aimport_module(self, module_name):
164 def aimport_module(self, module_name):
165 """Import a module, and mark it reloadable
165 """Import a module, and mark it reloadable
166
166
167 Returns
167 Returns
168 -------
168 -------
169 top_module : module
169 top_module : module
170 The imported module if it is top-level, or the top-level
170 The imported module if it is top-level, or the top-level
171 top_name : module
171 top_name : module
172 Name of top_module
172 Name of top_module
173
173
174 """
174 """
175 self.mark_module_reloadable(module_name)
175 self.mark_module_reloadable(module_name)
176
176
177 import_module(module_name)
177 import_module(module_name)
178 top_name = module_name.split('.')[0]
178 top_name = module_name.split('.')[0]
179 top_module = sys.modules[top_name]
179 top_module = sys.modules[top_name]
180 return top_module, top_name
180 return top_module, top_name
181
181
182 def filename_and_mtime(self, module):
182 def filename_and_mtime(self, module):
183 if not hasattr(module, '__file__') or module.__file__ is None:
183 if not hasattr(module, '__file__') or module.__file__ is None:
184 return None, None
184 return None, None
185
185
186 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
186 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
187 # we cannot reload(__main__) or reload(__mp_main__)
187 # we cannot reload(__main__) or reload(__mp_main__)
188 return None, None
188 return None, None
189
189
190 filename = module.__file__
190 filename = module.__file__
191 path, ext = os.path.splitext(filename)
191 path, ext = os.path.splitext(filename)
192
192
193 if ext.lower() == '.py':
193 if ext.lower() == '.py':
194 py_filename = filename
194 py_filename = filename
195 else:
195 else:
196 try:
196 try:
197 py_filename = source_from_cache(filename)
197 py_filename = source_from_cache(filename)
198 except ValueError:
198 except ValueError:
199 return None, None
199 return None, None
200
200
201 try:
201 try:
202 pymtime = os.stat(py_filename).st_mtime
202 pymtime = os.stat(py_filename).st_mtime
203 except OSError:
203 except OSError:
204 return None, None
204 return None, None
205
205
206 return py_filename, pymtime
206 return py_filename, pymtime
207
207
208 def check(self, check_all=False, do_reload=True):
208 def check(self, check_all=False, do_reload=True):
209 """Check whether some modules need to be reloaded."""
209 """Check whether some modules need to be reloaded."""
210
210
211 if not self.enabled and not check_all:
211 if not self.enabled and not check_all:
212 return
212 return
213
213
214 if check_all or self.check_all:
214 if check_all or self.check_all:
215 modules = list(sys.modules.keys())
215 modules = list(sys.modules.keys())
216 else:
216 else:
217 modules = list(self.modules.keys())
217 modules = list(self.modules.keys())
218
218
219 for modname in modules:
219 for modname in modules:
220 m = sys.modules.get(modname, None)
220 m = sys.modules.get(modname, None)
221
221
222 if modname in self.skip_modules:
222 if modname in self.skip_modules:
223 continue
223 continue
224
224
225 py_filename, pymtime = self.filename_and_mtime(m)
225 py_filename, pymtime = self.filename_and_mtime(m)
226 if py_filename is None:
226 if py_filename is None:
227 continue
227 continue
228
228
229 try:
229 try:
230 if pymtime <= self.modules_mtimes[modname]:
230 if pymtime <= self.modules_mtimes[modname]:
231 continue
231 continue
232 except KeyError:
232 except KeyError:
233 self.modules_mtimes[modname] = pymtime
233 self.modules_mtimes[modname] = pymtime
234 continue
234 continue
235 else:
235 else:
236 if self.failed.get(py_filename, None) == pymtime:
236 if self.failed.get(py_filename, None) == pymtime:
237 continue
237 continue
238
238
239 self.modules_mtimes[modname] = pymtime
239 self.modules_mtimes[modname] = pymtime
240
240
241 # If we've reached this point, we should try to reload the module
241 # If we've reached this point, we should try to reload the module
242 if do_reload:
242 if do_reload:
243 try:
243 try:
244 superreload(m, reload, self.old_objects)
244 superreload(m, reload, self.old_objects)
245 if py_filename in self.failed:
245 if py_filename in self.failed:
246 del self.failed[py_filename]
246 del self.failed[py_filename]
247 except:
247 except:
248 print("[autoreload of %s failed: %s]" % (
248 print("[autoreload of %s failed: %s]" % (
249 modname, traceback.format_exc(10)), file=sys.stderr)
249 modname, traceback.format_exc(10)), file=sys.stderr)
250 self.failed[py_filename] = pymtime
250 self.failed[py_filename] = pymtime
251
251
252 #------------------------------------------------------------------------------
252 #------------------------------------------------------------------------------
253 # superreload
253 # superreload
254 #------------------------------------------------------------------------------
254 #------------------------------------------------------------------------------
255
255
256
256
257 func_attrs = ['__code__', '__defaults__', '__doc__',
257 func_attrs = ['__code__', '__defaults__', '__doc__',
258 '__closure__', '__globals__', '__dict__']
258 '__closure__', '__globals__', '__dict__']
259
259
260
260
261 def update_function(old, new):
261 def update_function(old, new):
262 """Upgrade the code object of a function"""
262 """Upgrade the code object of a function"""
263 for name in func_attrs:
263 for name in func_attrs:
264 try:
264 try:
265 setattr(old, name, getattr(new, name))
265 setattr(old, name, getattr(new, name))
266 except (AttributeError, TypeError):
266 except (AttributeError, TypeError):
267 pass
267 pass
268
268
269
269
270 def update_class(old, new):
270 def update_class(old, new):
271 """Replace stuff in the __dict__ of a class, and upgrade
271 """Replace stuff in the __dict__ of a class, and upgrade
272 method code objects"""
272 method code objects, and add new methods, if any"""
273 for key in list(old.__dict__.keys()):
273 for key in list(old.__dict__.keys()):
274 old_obj = getattr(old, key)
274 old_obj = getattr(old, key)
275 try:
275 try:
276 new_obj = getattr(new, key)
276 new_obj = getattr(new, key)
277 if old_obj == new_obj:
277 if old_obj == new_obj:
278 continue
278 continue
279 except AttributeError:
279 except AttributeError:
280 # obsolete attribute: remove it
280 # obsolete attribute: remove it
281 try:
281 try:
282 delattr(old, key)
282 delattr(old, key)
283 except (AttributeError, TypeError):
283 except (AttributeError, TypeError):
284 pass
284 pass
285 continue
285 continue
286
286
287 if update_generic(old_obj, new_obj): continue
287 if update_generic(old_obj, new_obj): continue
288
288
289 try:
289 try:
290 setattr(old, key, getattr(new, key))
290 setattr(old, key, getattr(new, key))
291 except (AttributeError, TypeError):
291 except (AttributeError, TypeError):
292 pass # skip non-writable attributes
292 pass # skip non-writable attributes
293
293
294 for key in list(new.__dict__.keys()):
295 if key not in list(old.__dict__.keys()):
296 try:
297 setattr(old, key, getattr(new, key))
298 except (AttributeError, TypeError):
299 pass # skip non-writable attributes
300
294
301
295 def update_property(old, new):
302 def update_property(old, new):
296 """Replace get/set/del functions of a property"""
303 """Replace get/set/del functions of a property"""
297 update_generic(old.fdel, new.fdel)
304 update_generic(old.fdel, new.fdel)
298 update_generic(old.fget, new.fget)
305 update_generic(old.fget, new.fget)
299 update_generic(old.fset, new.fset)
306 update_generic(old.fset, new.fset)
300
307
301
308
302 def isinstance2(a, b, typ):
309 def isinstance2(a, b, typ):
303 return isinstance(a, typ) and isinstance(b, typ)
310 return isinstance(a, typ) and isinstance(b, typ)
304
311
305
312
306 UPDATE_RULES = [
313 UPDATE_RULES = [
307 (lambda a, b: isinstance2(a, b, type),
314 (lambda a, b: isinstance2(a, b, type),
308 update_class),
315 update_class),
309 (lambda a, b: isinstance2(a, b, types.FunctionType),
316 (lambda a, b: isinstance2(a, b, types.FunctionType),
310 update_function),
317 update_function),
311 (lambda a, b: isinstance2(a, b, property),
318 (lambda a, b: isinstance2(a, b, property),
312 update_property),
319 update_property),
313 ]
320 ]
314 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
321 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
315 lambda a, b: update_function(a.__func__, b.__func__)),
322 lambda a, b: update_function(a.__func__, b.__func__)),
316 ])
323 ])
317
324
318
325
319 def update_generic(a, b):
326 def update_generic(a, b):
320 for type_check, update in UPDATE_RULES:
327 for type_check, update in UPDATE_RULES:
321 if type_check(a, b):
328 if type_check(a, b):
322 update(a, b)
329 update(a, b)
323 return True
330 return True
324 return False
331 return False
325
332
326
333
327 class StrongRef(object):
334 class StrongRef(object):
328 def __init__(self, obj):
335 def __init__(self, obj):
329 self.obj = obj
336 self.obj = obj
330 def __call__(self):
337 def __call__(self):
331 return self.obj
338 return self.obj
332
339
333
340
334 def superreload(module, reload=reload, old_objects={}):
341 def superreload(module, reload=reload, old_objects={}):
335 """Enhanced version of the builtin reload function.
342 """Enhanced version of the builtin reload function.
336
343
337 superreload remembers objects previously in the module, and
344 superreload remembers objects previously in the module, and
338
345
339 - upgrades the class dictionary of every old class in the module
346 - upgrades the class dictionary of every old class in the module
340 - upgrades the code object of every old function and method
347 - upgrades the code object of every old function and method
341 - clears the module's namespace before reloading
348 - clears the module's namespace before reloading
342
349
343 """
350 """
344
351
345 # collect old objects in the module
352 # collect old objects in the module
346 for name, obj in list(module.__dict__.items()):
353 for name, obj in list(module.__dict__.items()):
347 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
354 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
348 continue
355 continue
349 key = (module.__name__, name)
356 key = (module.__name__, name)
350 try:
357 try:
351 old_objects.setdefault(key, []).append(weakref.ref(obj))
358 old_objects.setdefault(key, []).append(weakref.ref(obj))
352 except TypeError:
359 except TypeError:
353 pass
360 pass
354
361
355 # reload module
362 # reload module
356 try:
363 try:
357 # clear namespace first from old cruft
364 # clear namespace first from old cruft
358 old_dict = module.__dict__.copy()
365 old_dict = module.__dict__.copy()
359 old_name = module.__name__
366 old_name = module.__name__
360 module.__dict__.clear()
367 module.__dict__.clear()
361 module.__dict__['__name__'] = old_name
368 module.__dict__['__name__'] = old_name
362 module.__dict__['__loader__'] = old_dict['__loader__']
369 module.__dict__['__loader__'] = old_dict['__loader__']
363 except (TypeError, AttributeError, KeyError):
370 except (TypeError, AttributeError, KeyError):
364 pass
371 pass
365
372
366 try:
373 try:
367 module = reload(module)
374 module = reload(module)
368 except:
375 except:
369 # restore module dictionary on failed reload
376 # restore module dictionary on failed reload
370 module.__dict__.update(old_dict)
377 module.__dict__.update(old_dict)
371 raise
378 raise
372
379
373 # iterate over all objects and update functions & classes
380 # iterate over all objects and update functions & classes
374 for name, new_obj in list(module.__dict__.items()):
381 for name, new_obj in list(module.__dict__.items()):
375 key = (module.__name__, name)
382 key = (module.__name__, name)
376 if key not in old_objects: continue
383 if key not in old_objects: continue
377
384
378 new_refs = []
385 new_refs = []
379 for old_ref in old_objects[key]:
386 for old_ref in old_objects[key]:
380 old_obj = old_ref()
387 old_obj = old_ref()
381 if old_obj is None: continue
388 if old_obj is None: continue
382 new_refs.append(old_ref)
389 new_refs.append(old_ref)
383 update_generic(old_obj, new_obj)
390 update_generic(old_obj, new_obj)
384
391
385 if new_refs:
392 if new_refs:
386 old_objects[key] = new_refs
393 old_objects[key] = new_refs
387 else:
394 else:
388 del old_objects[key]
395 del old_objects[key]
389
396
390 return module
397 return module
391
398
392 #------------------------------------------------------------------------------
399 #------------------------------------------------------------------------------
393 # IPython connectivity
400 # IPython connectivity
394 #------------------------------------------------------------------------------
401 #------------------------------------------------------------------------------
395
402
396 from IPython.core.magic import Magics, magics_class, line_magic
403 from IPython.core.magic import Magics, magics_class, line_magic
397
404
398 @magics_class
405 @magics_class
399 class AutoreloadMagics(Magics):
406 class AutoreloadMagics(Magics):
400 def __init__(self, *a, **kw):
407 def __init__(self, *a, **kw):
401 super(AutoreloadMagics, self).__init__(*a, **kw)
408 super(AutoreloadMagics, self).__init__(*a, **kw)
402 self._reloader = ModuleReloader()
409 self._reloader = ModuleReloader()
403 self._reloader.check_all = False
410 self._reloader.check_all = False
404 self.loaded_modules = set(sys.modules)
411 self.loaded_modules = set(sys.modules)
405
412
406 @line_magic
413 @line_magic
407 def autoreload(self, parameter_s=''):
414 def autoreload(self, parameter_s=''):
408 r"""%autoreload => Reload modules automatically
415 r"""%autoreload => Reload modules automatically
409
416
410 %autoreload
417 %autoreload
411 Reload all modules (except those excluded by %aimport) automatically
418 Reload all modules (except those excluded by %aimport) automatically
412 now.
419 now.
413
420
414 %autoreload 0
421 %autoreload 0
415 Disable automatic reloading.
422 Disable automatic reloading.
416
423
417 %autoreload 1
424 %autoreload 1
418 Reload all modules imported with %aimport every time before executing
425 Reload all modules imported with %aimport every time before executing
419 the Python code typed.
426 the Python code typed.
420
427
421 %autoreload 2
428 %autoreload 2
422 Reload all modules (except those excluded by %aimport) every time
429 Reload all modules (except those excluded by %aimport) every time
423 before executing the Python code typed.
430 before executing the Python code typed.
424
431
425 Reloading Python modules in a reliable way is in general
432 Reloading Python modules in a reliable way is in general
426 difficult, and unexpected things may occur. %autoreload tries to
433 difficult, and unexpected things may occur. %autoreload tries to
427 work around common pitfalls by replacing function code objects and
434 work around common pitfalls by replacing function code objects and
428 parts of classes previously in the module with new versions. This
435 parts of classes previously in the module with new versions. This
429 makes the following things to work:
436 makes the following things to work:
430
437
431 - Functions and classes imported via 'from xxx import foo' are upgraded
438 - Functions and classes imported via 'from xxx import foo' are upgraded
432 to new versions when 'xxx' is reloaded.
439 to new versions when 'xxx' is reloaded.
433
440
434 - Methods and properties of classes are upgraded on reload, so that
441 - Methods and properties of classes are upgraded on reload, so that
435 calling 'c.foo()' on an object 'c' created before the reload causes
442 calling 'c.foo()' on an object 'c' created before the reload causes
436 the new code for 'foo' to be executed.
443 the new code for 'foo' to be executed.
437
444
438 Some of the known remaining caveats are:
445 Some of the known remaining caveats are:
439
446
440 - Replacing code objects does not always succeed: changing a @property
447 - Replacing code objects does not always succeed: changing a @property
441 in a class to an ordinary method or a method to a member variable
448 in a class to an ordinary method or a method to a member variable
442 can cause problems (but in old objects only).
449 can cause problems (but in old objects only).
443
450
444 - Functions that are removed (eg. via monkey-patching) from a module
451 - Functions that are removed (eg. via monkey-patching) from a module
445 before it is reloaded are not upgraded.
452 before it is reloaded are not upgraded.
446
453
447 - C extension modules cannot be reloaded, and so cannot be
454 - C extension modules cannot be reloaded, and so cannot be
448 autoreloaded.
455 autoreloaded.
449
456
450 """
457 """
451 if parameter_s == '':
458 if parameter_s == '':
452 self._reloader.check(True)
459 self._reloader.check(True)
453 elif parameter_s == '0':
460 elif parameter_s == '0':
454 self._reloader.enabled = False
461 self._reloader.enabled = False
455 elif parameter_s == '1':
462 elif parameter_s == '1':
456 self._reloader.check_all = False
463 self._reloader.check_all = False
457 self._reloader.enabled = True
464 self._reloader.enabled = True
458 elif parameter_s == '2':
465 elif parameter_s == '2':
459 self._reloader.check_all = True
466 self._reloader.check_all = True
460 self._reloader.enabled = True
467 self._reloader.enabled = True
461
468
462 @line_magic
469 @line_magic
463 def aimport(self, parameter_s='', stream=None):
470 def aimport(self, parameter_s='', stream=None):
464 """%aimport => Import modules for automatic reloading.
471 """%aimport => Import modules for automatic reloading.
465
472
466 %aimport
473 %aimport
467 List modules to automatically import and not to import.
474 List modules to automatically import and not to import.
468
475
469 %aimport foo
476 %aimport foo
470 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
477 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
471
478
472 %aimport foo, bar
479 %aimport foo, bar
473 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
480 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
474
481
475 %aimport -foo
482 %aimport -foo
476 Mark module 'foo' to not be autoreloaded for %autoreload 1
483 Mark module 'foo' to not be autoreloaded for %autoreload 1
477 """
484 """
478 modname = parameter_s
485 modname = parameter_s
479 if not modname:
486 if not modname:
480 to_reload = sorted(self._reloader.modules.keys())
487 to_reload = sorted(self._reloader.modules.keys())
481 to_skip = sorted(self._reloader.skip_modules.keys())
488 to_skip = sorted(self._reloader.skip_modules.keys())
482 if stream is None:
489 if stream is None:
483 stream = sys.stdout
490 stream = sys.stdout
484 if self._reloader.check_all:
491 if self._reloader.check_all:
485 stream.write("Modules to reload:\nall-except-skipped\n")
492 stream.write("Modules to reload:\nall-except-skipped\n")
486 else:
493 else:
487 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
494 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
488 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
495 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
489 elif modname.startswith('-'):
496 elif modname.startswith('-'):
490 modname = modname[1:]
497 modname = modname[1:]
491 self._reloader.mark_module_skipped(modname)
498 self._reloader.mark_module_skipped(modname)
492 else:
499 else:
493 for _module in ([_.strip() for _ in modname.split(',')]):
500 for _module in ([_.strip() for _ in modname.split(',')]):
494 top_module, top_name = self._reloader.aimport_module(_module)
501 top_module, top_name = self._reloader.aimport_module(_module)
495
502
496 # Inject module to user namespace
503 # Inject module to user namespace
497 self.shell.push({top_name: top_module})
504 self.shell.push({top_name: top_module})
498
505
499 def pre_run_cell(self):
506 def pre_run_cell(self):
500 if self._reloader.enabled:
507 if self._reloader.enabled:
501 try:
508 try:
502 self._reloader.check()
509 self._reloader.check()
503 except:
510 except:
504 pass
511 pass
505
512
506 def post_execute_hook(self):
513 def post_execute_hook(self):
507 """Cache the modification times of any modules imported in this execution
514 """Cache the modification times of any modules imported in this execution
508 """
515 """
509 newly_loaded_modules = set(sys.modules) - self.loaded_modules
516 newly_loaded_modules = set(sys.modules) - self.loaded_modules
510 for modname in newly_loaded_modules:
517 for modname in newly_loaded_modules:
511 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
518 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
512 if pymtime is not None:
519 if pymtime is not None:
513 self._reloader.modules_mtimes[modname] = pymtime
520 self._reloader.modules_mtimes[modname] = pymtime
514
521
515 self.loaded_modules.update(newly_loaded_modules)
522 self.loaded_modules.update(newly_loaded_modules)
516
523
517
524
518 def load_ipython_extension(ip):
525 def load_ipython_extension(ip):
519 """Load the extension in IPython."""
526 """Load the extension in IPython."""
520 auto_reload = AutoreloadMagics(ip)
527 auto_reload = AutoreloadMagics(ip)
521 ip.register_magics(auto_reload)
528 ip.register_magics(auto_reload)
522 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
529 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
523 ip.events.register('post_execute', auto_reload.post_execute_hook)
530 ip.events.register('post_execute', auto_reload.post_execute_hook)
@@ -1,342 +1,404 b''
1 """Tests for autoreload extension.
1 """Tests for autoreload extension.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 IPython Development Team.
4 # Copyright (c) 2012 IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import textwrap
18 import textwrap
19 import shutil
19 import shutil
20 import random
20 import random
21 import time
21 import time
22 from io import StringIO
22 from io import StringIO
23
23
24 import nose.tools as nt
24 import nose.tools as nt
25 import IPython.testing.tools as tt
25 import IPython.testing.tools as tt
26
26
27 from IPython.testing.decorators import skipif
27 from IPython.testing.decorators import skipif
28
28
29 from IPython.extensions.autoreload import AutoreloadMagics
29 from IPython.extensions.autoreload import AutoreloadMagics
30 from IPython.core.events import EventManager, pre_run_cell
30 from IPython.core.events import EventManager, pre_run_cell
31
31
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33 # Test fixture
33 # Test fixture
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35
35
36 noop = lambda *a, **kw: None
36 noop = lambda *a, **kw: None
37
37
38 class FakeShell(object):
38 class FakeShell(object):
39
39
40 def __init__(self):
40 def __init__(self):
41 self.ns = {}
41 self.ns = {}
42 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
42 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
43 self.auto_magics = AutoreloadMagics(shell=self)
43 self.auto_magics = AutoreloadMagics(shell=self)
44 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
44 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
45
45
46 register_magics = set_hook = noop
46 register_magics = set_hook = noop
47
47
48 def run_code(self, code):
48 def run_code(self, code):
49 self.events.trigger('pre_run_cell')
49 self.events.trigger('pre_run_cell')
50 exec(code, self.ns)
50 exec(code, self.ns)
51 self.auto_magics.post_execute_hook()
51 self.auto_magics.post_execute_hook()
52
52
53 def push(self, items):
53 def push(self, items):
54 self.ns.update(items)
54 self.ns.update(items)
55
55
56 def magic_autoreload(self, parameter):
56 def magic_autoreload(self, parameter):
57 self.auto_magics.autoreload(parameter)
57 self.auto_magics.autoreload(parameter)
58
58
59 def magic_aimport(self, parameter, stream=None):
59 def magic_aimport(self, parameter, stream=None):
60 self.auto_magics.aimport(parameter, stream=stream)
60 self.auto_magics.aimport(parameter, stream=stream)
61 self.auto_magics.post_execute_hook()
61 self.auto_magics.post_execute_hook()
62
62
63
63
64 class Fixture(object):
64 class Fixture(object):
65 """Fixture for creating test module files"""
65 """Fixture for creating test module files"""
66
66
67 test_dir = None
67 test_dir = None
68 old_sys_path = None
68 old_sys_path = None
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
69 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
70
70
71 def setUp(self):
71 def setUp(self):
72 self.test_dir = tempfile.mkdtemp()
72 self.test_dir = tempfile.mkdtemp()
73 self.old_sys_path = list(sys.path)
73 self.old_sys_path = list(sys.path)
74 sys.path.insert(0, self.test_dir)
74 sys.path.insert(0, self.test_dir)
75 self.shell = FakeShell()
75 self.shell = FakeShell()
76
76
77 def tearDown(self):
77 def tearDown(self):
78 shutil.rmtree(self.test_dir)
78 shutil.rmtree(self.test_dir)
79 sys.path = self.old_sys_path
79 sys.path = self.old_sys_path
80
80
81 self.test_dir = None
81 self.test_dir = None
82 self.old_sys_path = None
82 self.old_sys_path = None
83 self.shell = None
83 self.shell = None
84
84
85 def get_module(self):
85 def get_module(self):
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
86 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
87 if module_name in sys.modules:
87 if module_name in sys.modules:
88 del sys.modules[module_name]
88 del sys.modules[module_name]
89 file_name = os.path.join(self.test_dir, module_name + ".py")
89 file_name = os.path.join(self.test_dir, module_name + ".py")
90 return module_name, file_name
90 return module_name, file_name
91
91
92 def write_file(self, filename, content):
92 def write_file(self, filename, content):
93 """
93 """
94 Write a file, and force a timestamp difference of at least one second
94 Write a file, and force a timestamp difference of at least one second
95
95
96 Notes
96 Notes
97 -----
97 -----
98 Python's .pyc files record the timestamp of their compilation
98 Python's .pyc files record the timestamp of their compilation
99 with a time resolution of one second.
99 with a time resolution of one second.
100
100
101 Therefore, we need to force a timestamp difference between .py
101 Therefore, we need to force a timestamp difference between .py
102 and .pyc, without having the .py file be timestamped in the
102 and .pyc, without having the .py file be timestamped in the
103 future, and without changing the timestamp of the .pyc file
103 future, and without changing the timestamp of the .pyc file
104 (because that is stored in the file). The only reliable way
104 (because that is stored in the file). The only reliable way
105 to achieve this seems to be to sleep.
105 to achieve this seems to be to sleep.
106 """
106 """
107
107
108 # Sleep one second + eps
108 # Sleep one second + eps
109 time.sleep(1.05)
109 time.sleep(1.05)
110
110
111 # Write
111 # Write
112 f = open(filename, 'w')
112 f = open(filename, 'w')
113 try:
113 try:
114 f.write(content)
114 f.write(content)
115 finally:
115 finally:
116 f.close()
116 f.close()
117
117
118 def new_module(self, code):
118 def new_module(self, code):
119 mod_name, mod_fn = self.get_module()
119 mod_name, mod_fn = self.get_module()
120 f = open(mod_fn, 'w')
120 f = open(mod_fn, 'w')
121 try:
121 try:
122 f.write(code)
122 f.write(code)
123 finally:
123 finally:
124 f.close()
124 f.close()
125 return mod_name, mod_fn
125 return mod_name, mod_fn
126
126
127 #-----------------------------------------------------------------------------
127 #-----------------------------------------------------------------------------
128 # Test automatic reloading
128 # Test automatic reloading
129 #-----------------------------------------------------------------------------
129 #-----------------------------------------------------------------------------
130
130
131 class TestAutoreload(Fixture):
131 class TestAutoreload(Fixture):
132
132
133 @skipif(sys.version_info < (3, 6))
133 @skipif(sys.version_info < (3, 6))
134 def test_reload_enums(self):
134 def test_reload_enums(self):
135 import enum
135 import enum
136 mod_name, mod_fn = self.new_module(textwrap.dedent("""
136 mod_name, mod_fn = self.new_module(textwrap.dedent("""
137 from enum import Enum
137 from enum import Enum
138 class MyEnum(Enum):
138 class MyEnum(Enum):
139 A = 'A'
139 A = 'A'
140 B = 'B'
140 B = 'B'
141 """))
141 """))
142 self.shell.magic_autoreload("2")
142 self.shell.magic_autoreload("2")
143 self.shell.magic_aimport(mod_name)
143 self.shell.magic_aimport(mod_name)
144 self.write_file(mod_fn, textwrap.dedent("""
144 self.write_file(mod_fn, textwrap.dedent("""
145 from enum import Enum
145 from enum import Enum
146 class MyEnum(Enum):
146 class MyEnum(Enum):
147 A = 'A'
147 A = 'A'
148 B = 'B'
148 B = 'B'
149 C = 'C'
149 C = 'C'
150 """))
150 """))
151 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
151 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
152 self.shell.run_code("pass") # trigger another reload
152 self.shell.run_code("pass") # trigger another reload
153
153
154 def test_reload_class_attributes(self):
155 self.shell.magic_autoreload("2")
156 mod_name, mod_fn = self.new_module(textwrap.dedent("""
157 class MyClass:
158
159 def __init__(self, a=10):
160 self.a = a
161 self.b = 22
162 # self.toto = 33
163
164 def square(self):
165 print('compute square')
166 return self.a*self.a
167 """
168 )
169 )
170 self.shell.run_code("from %s import MyClass" % mod_name)
171 self.shell.run_code("first = MyClass(5)")
172 self.shell.run_code("first.square()")
173 with nt.assert_raises(AttributeError):
174 self.shell.run_code("first.cube()")
175 with nt.assert_raises(AttributeError):
176 self.shell.run_code("first.power(5)")
177 self.shell.run_code("first.b")
178 with nt.assert_raises(AttributeError):
179 self.shell.run_code("first.toto")
180
181 # remove square, add power
182
183 self.write_file(
184 mod_fn,
185 textwrap.dedent(
186 """
187 class MyClass:
188
189 def __init__(self, a=10):
190 self.a = a
191 self.b = 11
192
193 def power(self, p):
194 print('compute power '+str(p))
195 return self.a**p
196 """
197 ),
198 )
199
200 self.shell.run_code("second = MyClass(5)")
201
202 for object_name in {'first', 'second'}:
203 self.shell.run_code("{object_name}.power(5)".format(object_name=object_name))
204 with nt.assert_raises(AttributeError):
205 self.shell.run_code("{object_name}.cube()".format(object_name=object_name))
206 with nt.assert_raises(AttributeError):
207 self.shell.run_code("{object_name}.square()".format(object_name=object_name))
208 self.shell.run_code("{object_name}.b".format(object_name=object_name))
209 self.shell.run_code("{object_name}.a".format(object_name=object_name))
210 with nt.assert_raises(AttributeError):
211 self.shell.run_code("{object_name}.toto".format(object_name=object_name))
154
212
155 def _check_smoketest(self, use_aimport=True):
213 def _check_smoketest(self, use_aimport=True):
156 """
214 """
157 Functional test for the automatic reloader using either
215 Functional test for the automatic reloader using either
158 '%autoreload 1' or '%autoreload 2'
216 '%autoreload 1' or '%autoreload 2'
159 """
217 """
160
218
161 mod_name, mod_fn = self.new_module("""
219 mod_name, mod_fn = self.new_module("""
162 x = 9
220 x = 9
163
221
164 z = 123 # this item will be deleted
222 z = 123 # this item will be deleted
165
223
166 def foo(y):
224 def foo(y):
167 return y + 3
225 return y + 3
168
226
169 class Baz(object):
227 class Baz(object):
170 def __init__(self, x):
228 def __init__(self, x):
171 self.x = x
229 self.x = x
172 def bar(self, y):
230 def bar(self, y):
173 return self.x + y
231 return self.x + y
174 @property
232 @property
175 def quux(self):
233 def quux(self):
176 return 42
234 return 42
177 def zzz(self):
235 def zzz(self):
178 '''This method will be deleted below'''
236 '''This method will be deleted below'''
179 return 99
237 return 99
180
238
181 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
239 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
182 def foo(self):
240 def foo(self):
183 return 1
241 return 1
184 """)
242 """)
185
243
186 #
244 #
187 # Import module, and mark for reloading
245 # Import module, and mark for reloading
188 #
246 #
189 if use_aimport:
247 if use_aimport:
190 self.shell.magic_autoreload("1")
248 self.shell.magic_autoreload("1")
191 self.shell.magic_aimport(mod_name)
249 self.shell.magic_aimport(mod_name)
192 stream = StringIO()
250 stream = StringIO()
193 self.shell.magic_aimport("", stream=stream)
251 self.shell.magic_aimport("", stream=stream)
194 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
252 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
195
253
196 with nt.assert_raises(ImportError):
254 with nt.assert_raises(ImportError):
197 self.shell.magic_aimport("tmpmod_as318989e89ds")
255 self.shell.magic_aimport("tmpmod_as318989e89ds")
198 else:
256 else:
199 self.shell.magic_autoreload("2")
257 self.shell.magic_autoreload("2")
200 self.shell.run_code("import %s" % mod_name)
258 self.shell.run_code("import %s" % mod_name)
201 stream = StringIO()
259 stream = StringIO()
202 self.shell.magic_aimport("", stream=stream)
260 self.shell.magic_aimport("", stream=stream)
203 nt.assert_true("Modules to reload:\nall-except-skipped" in
261 nt.assert_true("Modules to reload:\nall-except-skipped" in
204 stream.getvalue())
262 stream.getvalue())
205 nt.assert_in(mod_name, self.shell.ns)
263 nt.assert_in(mod_name, self.shell.ns)
206
264
207 mod = sys.modules[mod_name]
265 mod = sys.modules[mod_name]
208
266
209 #
267 #
210 # Test module contents
268 # Test module contents
211 #
269 #
212 old_foo = mod.foo
270 old_foo = mod.foo
213 old_obj = mod.Baz(9)
271 old_obj = mod.Baz(9)
214 old_obj2 = mod.Bar()
272 old_obj2 = mod.Bar()
215
273
216 def check_module_contents():
274 def check_module_contents():
217 nt.assert_equal(mod.x, 9)
275 nt.assert_equal(mod.x, 9)
218 nt.assert_equal(mod.z, 123)
276 nt.assert_equal(mod.z, 123)
219
277
220 nt.assert_equal(old_foo(0), 3)
278 nt.assert_equal(old_foo(0), 3)
221 nt.assert_equal(mod.foo(0), 3)
279 nt.assert_equal(mod.foo(0), 3)
222
280
223 obj = mod.Baz(9)
281 obj = mod.Baz(9)
224 nt.assert_equal(old_obj.bar(1), 10)
282 nt.assert_equal(old_obj.bar(1), 10)
225 nt.assert_equal(obj.bar(1), 10)
283 nt.assert_equal(obj.bar(1), 10)
226 nt.assert_equal(obj.quux, 42)
284 nt.assert_equal(obj.quux, 42)
227 nt.assert_equal(obj.zzz(), 99)
285 nt.assert_equal(obj.zzz(), 99)
228
286
229 obj2 = mod.Bar()
287 obj2 = mod.Bar()
230 nt.assert_equal(old_obj2.foo(), 1)
288 nt.assert_equal(old_obj2.foo(), 1)
231 nt.assert_equal(obj2.foo(), 1)
289 nt.assert_equal(obj2.foo(), 1)
232
290
233 check_module_contents()
291 check_module_contents()
234
292
235 #
293 #
236 # Simulate a failed reload: no reload should occur and exactly
294 # Simulate a failed reload: no reload should occur and exactly
237 # one error message should be printed
295 # one error message should be printed
238 #
296 #
239 self.write_file(mod_fn, """
297 self.write_file(mod_fn, """
240 a syntax error
298 a syntax error
241 """)
299 """)
242
300
243 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
301 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
244 self.shell.run_code("pass") # trigger reload
302 self.shell.run_code("pass") # trigger reload
245 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
303 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
246 self.shell.run_code("pass") # trigger another reload
304 self.shell.run_code("pass") # trigger another reload
247 check_module_contents()
305 check_module_contents()
248
306
249 #
307 #
250 # Rewrite module (this time reload should succeed)
308 # Rewrite module (this time reload should succeed)
251 #
309 #
252 self.write_file(mod_fn, """
310 self.write_file(mod_fn, """
253 x = 10
311 x = 10
254
312
255 def foo(y):
313 def foo(y):
256 return y + 4
314 return y + 4
257
315
258 class Baz(object):
316 class Baz(object):
259 def __init__(self, x):
317 def __init__(self, x):
260 self.x = x
318 self.x = x
261 def bar(self, y):
319 def bar(self, y):
262 return self.x + y + 1
320 return self.x + y + 1
263 @property
321 @property
264 def quux(self):
322 def quux(self):
265 return 43
323 return 43
266
324
267 class Bar: # old-style class
325 class Bar: # old-style class
268 def foo(self):
326 def foo(self):
269 return 2
327 return 2
270 """)
328 """)
271
329
272 def check_module_contents():
330 def check_module_contents():
273 nt.assert_equal(mod.x, 10)
331 nt.assert_equal(mod.x, 10)
274 nt.assert_false(hasattr(mod, 'z'))
332 nt.assert_false(hasattr(mod, 'z'))
275
333
276 nt.assert_equal(old_foo(0), 4) # superreload magic!
334 nt.assert_equal(old_foo(0), 4) # superreload magic!
277 nt.assert_equal(mod.foo(0), 4)
335 nt.assert_equal(mod.foo(0), 4)
278
336
279 obj = mod.Baz(9)
337 obj = mod.Baz(9)
280 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
338 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
281 nt.assert_equal(obj.bar(1), 11)
339 nt.assert_equal(obj.bar(1), 11)
282
340
283 nt.assert_equal(old_obj.quux, 43)
341 nt.assert_equal(old_obj.quux, 43)
284 nt.assert_equal(obj.quux, 43)
342 nt.assert_equal(obj.quux, 43)
285
343
286 nt.assert_false(hasattr(old_obj, 'zzz'))
344 nt.assert_false(hasattr(old_obj, 'zzz'))
287 nt.assert_false(hasattr(obj, 'zzz'))
345 nt.assert_false(hasattr(obj, 'zzz'))
288
346
289 obj2 = mod.Bar()
347 obj2 = mod.Bar()
290 nt.assert_equal(old_obj2.foo(), 2)
348 nt.assert_equal(old_obj2.foo(), 2)
291 nt.assert_equal(obj2.foo(), 2)
349 nt.assert_equal(obj2.foo(), 2)
292
350
293 self.shell.run_code("pass") # trigger reload
351 self.shell.run_code("pass") # trigger reload
294 check_module_contents()
352 check_module_contents()
295
353
296 #
354 #
297 # Another failure case: deleted file (shouldn't reload)
355 # Another failure case: deleted file (shouldn't reload)
298 #
356 #
299 os.unlink(mod_fn)
357 os.unlink(mod_fn)
300
358
301 self.shell.run_code("pass") # trigger reload
359 self.shell.run_code("pass") # trigger reload
302 check_module_contents()
360 check_module_contents()
303
361
304 #
362 #
305 # Disable autoreload and rewrite module: no reload should occur
363 # Disable autoreload and rewrite module: no reload should occur
306 #
364 #
307 if use_aimport:
365 if use_aimport:
308 self.shell.magic_aimport("-" + mod_name)
366 self.shell.magic_aimport("-" + mod_name)
309 stream = StringIO()
367 stream = StringIO()
310 self.shell.magic_aimport("", stream=stream)
368 self.shell.magic_aimport("", stream=stream)
311 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
369 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
312 stream.getvalue())
370 stream.getvalue())
313
371
314 # This should succeed, although no such module exists
372 # This should succeed, although no such module exists
315 self.shell.magic_aimport("-tmpmod_as318989e89ds")
373 self.shell.magic_aimport("-tmpmod_as318989e89ds")
316 else:
374 else:
317 self.shell.magic_autoreload("0")
375 self.shell.magic_autoreload("0")
318
376
319 self.write_file(mod_fn, """
377 self.write_file(mod_fn, """
320 x = -99
378 x = -99
321 """)
379 """)
322
380
323 self.shell.run_code("pass") # trigger reload
381 self.shell.run_code("pass") # trigger reload
324 self.shell.run_code("pass")
382 self.shell.run_code("pass")
325 check_module_contents()
383 check_module_contents()
326
384
327 #
385 #
328 # Re-enable autoreload: reload should now occur
386 # Re-enable autoreload: reload should now occur
329 #
387 #
330 if use_aimport:
388 if use_aimport:
331 self.shell.magic_aimport(mod_name)
389 self.shell.magic_aimport(mod_name)
332 else:
390 else:
333 self.shell.magic_autoreload("")
391 self.shell.magic_autoreload("")
334
392
335 self.shell.run_code("pass") # trigger reload
393 self.shell.run_code("pass") # trigger reload
336 nt.assert_equal(mod.x, -99)
394 nt.assert_equal(mod.x, -99)
337
395
338 def test_smoketest_aimport(self):
396 def test_smoketest_aimport(self):
339 self._check_smoketest(use_aimport=True)
397 self._check_smoketest(use_aimport=True)
340
398
341 def test_smoketest_autoreload(self):
399 def test_smoketest_autoreload(self):
342 self._check_smoketest(use_aimport=False)
400 self._check_smoketest(use_aimport=False)
401
402
403
404
@@ -1,454 +1,460 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """IPython Test Suite Runner.
2 """IPython Test Suite Runner.
3
3
4 This module provides a main entry point to a user script to test IPython
4 This module provides a main entry point to a user script to test IPython
5 itself from the command line. There are two ways of running this script:
5 itself from the command line. There are two ways of running this script:
6
6
7 1. With the syntax `iptest all`. This runs our entire test suite by
7 1. With the syntax `iptest all`. This runs our entire test suite by
8 calling this script (with different arguments) recursively. This
8 calling this script (with different arguments) recursively. This
9 causes modules and package to be tested in different processes, using nose
9 causes modules and package to be tested in different processes, using nose
10 or trial where appropriate.
10 or trial where appropriate.
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
12 the script simply calls nose, but with special command line flags and
12 the script simply calls nose, but with special command line flags and
13 plugins loaded.
13 plugins loaded.
14
14
15 """
15 """
16
16
17 # Copyright (c) IPython Development Team.
17 # Copyright (c) IPython Development Team.
18 # Distributed under the terms of the Modified BSD License.
18 # Distributed under the terms of the Modified BSD License.
19
19
20
20
21 import glob
21 import glob
22 from io import BytesIO
22 from io import BytesIO
23 import os
23 import os
24 import os.path as path
24 import os.path as path
25 import sys
25 import sys
26 from threading import Thread, Lock, Event
26 from threading import Thread, Lock, Event
27 import warnings
27 import warnings
28
28
29 import nose.plugins.builtin
29 import nose.plugins.builtin
30 from nose.plugins.xunit import Xunit
30 from nose.plugins.xunit import Xunit
31 from nose import SkipTest
31 from nose import SkipTest
32 from nose.core import TestProgram
32 from nose.core import TestProgram
33 from nose.plugins import Plugin
33 from nose.plugins import Plugin
34 from nose.util import safe_str
34 from nose.util import safe_str
35
35
36 from IPython import version_info
36 from IPython import version_info
37 from IPython.utils.py3compat import decode
37 from IPython.utils.py3compat import decode
38 from IPython.utils.importstring import import_item
38 from IPython.utils.importstring import import_item
39 from IPython.testing.plugin.ipdoctest import IPythonDoctest
39 from IPython.testing.plugin.ipdoctest import IPythonDoctest
40 from IPython.external.decorators import KnownFailure, knownfailureif
40 from IPython.external.decorators import KnownFailure, knownfailureif
41
41
42 pjoin = path.join
42 pjoin = path.join
43
43
44
44
45 # Enable printing all warnings raise by IPython's modules
45 # Enable printing all warnings raise by IPython's modules
46 warnings.filterwarnings('ignore', message='.*Matplotlib is building the font cache.*', category=UserWarning, module='.*')
46 warnings.filterwarnings('ignore', message='.*Matplotlib is building the font cache.*', category=UserWarning, module='.*')
47 warnings.filterwarnings('error', message='.*', category=ResourceWarning, module='.*')
47 warnings.filterwarnings('error', message='.*', category=ResourceWarning, module='.*')
48 warnings.filterwarnings('error', message=".*{'config': True}.*", category=DeprecationWarning, module='IPy.*')
48 warnings.filterwarnings('error', message=".*{'config': True}.*", category=DeprecationWarning, module='IPy.*')
49 warnings.filterwarnings('default', message='.*', category=Warning, module='IPy.*')
49 warnings.filterwarnings('default', message='.*', category=Warning, module='IPy.*')
50
50
51 warnings.filterwarnings('error', message='.*apply_wrapper.*', category=DeprecationWarning, module='.*')
51 warnings.filterwarnings('error', message='.*apply_wrapper.*', category=DeprecationWarning, module='.*')
52 warnings.filterwarnings('error', message='.*make_label_dec', category=DeprecationWarning, module='.*')
52 warnings.filterwarnings('error', message='.*make_label_dec', category=DeprecationWarning, module='.*')
53 warnings.filterwarnings('error', message='.*decorated_dummy.*', category=DeprecationWarning, module='.*')
53 warnings.filterwarnings('error', message='.*decorated_dummy.*', category=DeprecationWarning, module='.*')
54 warnings.filterwarnings('error', message='.*skip_file_no_x11.*', category=DeprecationWarning, module='.*')
54 warnings.filterwarnings('error', message='.*skip_file_no_x11.*', category=DeprecationWarning, module='.*')
55 warnings.filterwarnings('error', message='.*onlyif_any_cmd_exists.*', category=DeprecationWarning, module='.*')
55 warnings.filterwarnings('error', message='.*onlyif_any_cmd_exists.*', category=DeprecationWarning, module='.*')
56
56
57 warnings.filterwarnings('error', message='.*disable_gui.*', category=DeprecationWarning, module='.*')
57 warnings.filterwarnings('error', message='.*disable_gui.*', category=DeprecationWarning, module='.*')
58
58
59 warnings.filterwarnings('error', message='.*ExceptionColors global is deprecated.*', category=DeprecationWarning, module='.*')
59 warnings.filterwarnings('error', message='.*ExceptionColors global is deprecated.*', category=DeprecationWarning, module='.*')
60
60
61 # Jedi older versions
61 # Jedi older versions
62 warnings.filterwarnings(
62 warnings.filterwarnings(
63 'error', message='.*elementwise != comparison failed and.*', category=FutureWarning, module='.*')
63 'error', message='.*elementwise != comparison failed and.*', category=FutureWarning, module='.*')
64
64
65 if version_info < (6,):
65 if version_info < (6,):
66 # nose.tools renames all things from `camelCase` to `snake_case` which raise an
66 # nose.tools renames all things from `camelCase` to `snake_case` which raise an
67 # warning with the runner they also import from standard import library. (as of Dec 2015)
67 # warning with the runner they also import from standard import library. (as of Dec 2015)
68 # Ignore, let's revisit that in a couple of years for IPython 6.
68 # Ignore, let's revisit that in a couple of years for IPython 6.
69 warnings.filterwarnings(
69 warnings.filterwarnings(
70 'ignore', message='.*Please use assertEqual instead', category=Warning, module='IPython.*')
70 'ignore', message='.*Please use assertEqual instead', category=Warning, module='IPython.*')
71
71
72 if version_info < (8,):
72 if version_info < (8,):
73 warnings.filterwarnings('ignore', message='.*Completer.complete.*',
73 warnings.filterwarnings('ignore', message='.*Completer.complete.*',
74 category=PendingDeprecationWarning, module='.*')
74 category=PendingDeprecationWarning, module='.*')
75 else:
75 else:
76 warnings.warn(
76 warnings.warn(
77 'Completer.complete was pending deprecation and should be changed to Deprecated', FutureWarning)
77 'Completer.complete was pending deprecation and should be changed to Deprecated', FutureWarning)
78
78
79
79
80
80
81 # ------------------------------------------------------------------------------
81 # ------------------------------------------------------------------------------
82 # Monkeypatch Xunit to count known failures as skipped.
82 # Monkeypatch Xunit to count known failures as skipped.
83 # ------------------------------------------------------------------------------
83 # ------------------------------------------------------------------------------
84 def monkeypatch_xunit():
84 def monkeypatch_xunit():
85 try:
85 try:
86 knownfailureif(True)(lambda: None)()
86 knownfailureif(True)(lambda: None)()
87 except Exception as e:
87 except Exception as e:
88 KnownFailureTest = type(e)
88 KnownFailureTest = type(e)
89
89
90 def addError(self, test, err, capt=None):
90 def addError(self, test, err, capt=None):
91 if issubclass(err[0], KnownFailureTest):
91 if issubclass(err[0], KnownFailureTest):
92 err = (SkipTest,) + err[1:]
92 err = (SkipTest,) + err[1:]
93 return self.orig_addError(test, err, capt)
93 return self.orig_addError(test, err, capt)
94
94
95 Xunit.orig_addError = Xunit.addError
95 Xunit.orig_addError = Xunit.addError
96 Xunit.addError = addError
96 Xunit.addError = addError
97
97
98 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
99 # Check which dependencies are installed and greater than minimum version.
99 # Check which dependencies are installed and greater than minimum version.
100 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
101 def extract_version(mod):
101 def extract_version(mod):
102 return mod.__version__
102 return mod.__version__
103
103
104 def test_for(item, min_version=None, callback=extract_version):
104 def test_for(item, min_version=None, callback=extract_version):
105 """Test to see if item is importable, and optionally check against a minimum
105 """Test to see if item is importable, and optionally check against a minimum
106 version.
106 version.
107
107
108 If min_version is given, the default behavior is to check against the
108 If min_version is given, the default behavior is to check against the
109 `__version__` attribute of the item, but specifying `callback` allows you to
109 `__version__` attribute of the item, but specifying `callback` allows you to
110 extract the value you are interested in. e.g::
110 extract the value you are interested in. e.g::
111
111
112 In [1]: import sys
112 In [1]: import sys
113
113
114 In [2]: from IPython.testing.iptest import test_for
114 In [2]: from IPython.testing.iptest import test_for
115
115
116 In [3]: test_for('sys', (2,6), callback=lambda sys: sys.version_info)
116 In [3]: test_for('sys', (2,6), callback=lambda sys: sys.version_info)
117 Out[3]: True
117 Out[3]: True
118
118
119 """
119 """
120 try:
120 try:
121 check = import_item(item)
121 check = import_item(item)
122 except (ImportError, RuntimeError):
122 except (ImportError, RuntimeError):
123 # GTK reports Runtime error if it can't be initialized even if it's
123 # GTK reports Runtime error if it can't be initialized even if it's
124 # importable.
124 # importable.
125 return False
125 return False
126 else:
126 else:
127 if min_version:
127 if min_version:
128 if callback:
128 if callback:
129 # extra processing step to get version to compare
129 # extra processing step to get version to compare
130 check = callback(check)
130 check = callback(check)
131
131
132 return check >= min_version
132 return check >= min_version
133 else:
133 else:
134 return True
134 return True
135
135
136 # Global dict where we can store information on what we have and what we don't
136 # Global dict where we can store information on what we have and what we don't
137 # have available at test run time
137 # have available at test run time
138 have = {'matplotlib': test_for('matplotlib'),
138 have = {'matplotlib': test_for('matplotlib'),
139 'pygments': test_for('pygments'),
139 'pygments': test_for('pygments'),
140 'sqlite3': test_for('sqlite3')}
140 'sqlite3': test_for('sqlite3')}
141
141
142 #-----------------------------------------------------------------------------
142 #-----------------------------------------------------------------------------
143 # Test suite definitions
143 # Test suite definitions
144 #-----------------------------------------------------------------------------
144 #-----------------------------------------------------------------------------
145
145
146 test_group_names = ['core',
146 test_group_names = ['core',
147 'extensions', 'lib', 'terminal', 'testing', 'utils',
147 'extensions', 'lib', 'terminal', 'testing', 'utils',
148 ]
148 ]
149
149
150 class TestSection(object):
150 class TestSection(object):
151 def __init__(self, name, includes):
151 def __init__(self, name, includes):
152 self.name = name
152 self.name = name
153 self.includes = includes
153 self.includes = includes
154 self.excludes = []
154 self.excludes = []
155 self.dependencies = []
155 self.dependencies = []
156 self.enabled = True
156 self.enabled = True
157
157
158 def exclude(self, module):
158 def exclude(self, module):
159 if not module.startswith('IPython'):
159 if not module.startswith('IPython'):
160 module = self.includes[0] + "." + module
160 module = self.includes[0] + "." + module
161 self.excludes.append(module.replace('.', os.sep))
161 self.excludes.append(module.replace('.', os.sep))
162
162
163 def requires(self, *packages):
163 def requires(self, *packages):
164 self.dependencies.extend(packages)
164 self.dependencies.extend(packages)
165
165
166 @property
166 @property
167 def will_run(self):
167 def will_run(self):
168 return self.enabled and all(have[p] for p in self.dependencies)
168 return self.enabled and all(have[p] for p in self.dependencies)
169
169
170 # Name -> (include, exclude, dependencies_met)
170 # Name -> (include, exclude, dependencies_met)
171 test_sections = {n:TestSection(n, ['IPython.%s' % n]) for n in test_group_names}
171 test_sections = {n:TestSection(n, ['IPython.%s' % n]) for n in test_group_names}
172
172
173
173
174 # Exclusions and dependencies
174 # Exclusions and dependencies
175 # ---------------------------
175 # ---------------------------
176
176
177 # core:
177 # core:
178 sec = test_sections['core']
178 sec = test_sections['core']
179 if not have['sqlite3']:
179 if not have['sqlite3']:
180 sec.exclude('tests.test_history')
180 sec.exclude('tests.test_history')
181 sec.exclude('history')
181 sec.exclude('history')
182 if not have['matplotlib']:
182 if not have['matplotlib']:
183 sec.exclude('pylabtools'),
183 sec.exclude('pylabtools'),
184 sec.exclude('tests.test_pylabtools')
184 sec.exclude('tests.test_pylabtools')
185
185
186 # lib:
186 # lib:
187 sec = test_sections['lib']
187 sec = test_sections['lib']
188 sec.exclude('kernel')
188 sec.exclude('kernel')
189 if not have['pygments']:
189 if not have['pygments']:
190 sec.exclude('tests.test_lexers')
190 sec.exclude('tests.test_lexers')
191 # We do this unconditionally, so that the test suite doesn't import
191 # We do this unconditionally, so that the test suite doesn't import
192 # gtk, changing the default encoding and masking some unicode bugs.
192 # gtk, changing the default encoding and masking some unicode bugs.
193 sec.exclude('inputhookgtk')
193 sec.exclude('inputhookgtk')
194 # We also do this unconditionally, because wx can interfere with Unix signals.
194 # We also do this unconditionally, because wx can interfere with Unix signals.
195 # There are currently no tests for it anyway.
195 # There are currently no tests for it anyway.
196 sec.exclude('inputhookwx')
196 sec.exclude('inputhookwx')
197 # Testing inputhook will need a lot of thought, to figure out
197 # Testing inputhook will need a lot of thought, to figure out
198 # how to have tests that don't lock up with the gui event
198 # how to have tests that don't lock up with the gui event
199 # loops in the picture
199 # loops in the picture
200 sec.exclude('inputhook')
200 sec.exclude('inputhook')
201
201
202 # testing:
202 # testing:
203 sec = test_sections['testing']
203 sec = test_sections['testing']
204 # These have to be skipped on win32 because they use echo, rm, cd, etc.
204 # These have to be skipped on win32 because they use echo, rm, cd, etc.
205 # See ticket https://github.com/ipython/ipython/issues/87
205 # See ticket https://github.com/ipython/ipython/issues/87
206 if sys.platform == 'win32':
206 if sys.platform == 'win32':
207 sec.exclude('plugin.test_exampleip')
207 sec.exclude('plugin.test_exampleip')
208 sec.exclude('plugin.dtexample')
208 sec.exclude('plugin.dtexample')
209
209
210 # don't run jupyter_console tests found via shim
210 # don't run jupyter_console tests found via shim
211 test_sections['terminal'].exclude('console')
211 test_sections['terminal'].exclude('console')
212
212
213 # extensions:
213 # extensions:
214 sec = test_sections['extensions']
214 sec = test_sections['extensions']
215 # This is deprecated in favour of rpy2
215 # This is deprecated in favour of rpy2
216 sec.exclude('rmagic')
216 sec.exclude('rmagic')
217 # autoreload does some strange stuff, so move it to its own test section
217 # autoreload does some strange stuff, so move it to its own test section
218 sec.exclude('autoreload')
218 sec.exclude('autoreload')
219 sec.exclude('tests.test_autoreload')
219 sec.exclude('tests.test_autoreload')
220 test_sections['autoreload'] = TestSection('autoreload',
220 test_sections['autoreload'] = TestSection('autoreload',
221 ['IPython.extensions.autoreload', 'IPython.extensions.tests.test_autoreload'])
221 ['IPython.extensions.autoreload', 'IPython.extensions.tests.test_autoreload'])
222 test_group_names.append('autoreload')
222 test_group_names.append('autoreload')
223
223
224
224
225 #-----------------------------------------------------------------------------
225 #-----------------------------------------------------------------------------
226 # Functions and classes
226 # Functions and classes
227 #-----------------------------------------------------------------------------
227 #-----------------------------------------------------------------------------
228
228
229 def check_exclusions_exist():
229 def check_exclusions_exist():
230 from IPython.paths import get_ipython_package_dir
230 from IPython.paths import get_ipython_package_dir
231 from warnings import warn
231 from warnings import warn
232 parent = os.path.dirname(get_ipython_package_dir())
232 parent = os.path.dirname(get_ipython_package_dir())
233 for sec in test_sections:
233 for sec in test_sections:
234 for pattern in sec.exclusions:
234 for pattern in sec.exclusions:
235 fullpath = pjoin(parent, pattern)
235 fullpath = pjoin(parent, pattern)
236 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
236 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
237 warn("Excluding nonexistent file: %r" % pattern)
237 warn("Excluding nonexistent file: %r" % pattern)
238
238
239
239
240 class ExclusionPlugin(Plugin):
240 class ExclusionPlugin(Plugin):
241 """A nose plugin to effect our exclusions of files and directories.
241 """A nose plugin to effect our exclusions of files and directories.
242 """
242 """
243 name = 'exclusions'
243 name = 'exclusions'
244 score = 3000 # Should come before any other plugins
244 score = 3000 # Should come before any other plugins
245
245
246 def __init__(self, exclude_patterns=None):
246 def __init__(self, exclude_patterns=None):
247 """
247 """
248 Parameters
248 Parameters
249 ----------
249 ----------
250
250
251 exclude_patterns : sequence of strings, optional
251 exclude_patterns : sequence of strings, optional
252 Filenames containing these patterns (as raw strings, not as regular
252 Filenames containing these patterns (as raw strings, not as regular
253 expressions) are excluded from the tests.
253 expressions) are excluded from the tests.
254 """
254 """
255 self.exclude_patterns = exclude_patterns or []
255 self.exclude_patterns = exclude_patterns or []
256 super(ExclusionPlugin, self).__init__()
256 super(ExclusionPlugin, self).__init__()
257
257
258 def options(self, parser, env=os.environ):
258 def options(self, parser, env=os.environ):
259 Plugin.options(self, parser, env)
259 Plugin.options(self, parser, env)
260
260
261 def configure(self, options, config):
261 def configure(self, options, config):
262 Plugin.configure(self, options, config)
262 Plugin.configure(self, options, config)
263 # Override nose trying to disable plugin.
263 # Override nose trying to disable plugin.
264 self.enabled = True
264 self.enabled = True
265
265
266 def wantFile(self, filename):
266 def wantFile(self, filename):
267 """Return whether the given filename should be scanned for tests.
267 """Return whether the given filename should be scanned for tests.
268 """
268 """
269 if any(pat in filename for pat in self.exclude_patterns):
269 if any(pat in filename for pat in self.exclude_patterns):
270 return False
270 return False
271 return None
271 return None
272
272
273 def wantDirectory(self, directory):
273 def wantDirectory(self, directory):
274 """Return whether the given directory should be scanned for tests.
274 """Return whether the given directory should be scanned for tests.
275 """
275 """
276 if any(pat in directory for pat in self.exclude_patterns):
276 if any(pat in directory for pat in self.exclude_patterns):
277 return False
277 return False
278 return None
278 return None
279
279
280
280
281 class StreamCapturer(Thread):
281 class StreamCapturer(Thread):
282 daemon = True # Don't hang if main thread crashes
282 daemon = True # Don't hang if main thread crashes
283 started = False
283 started = False
284 def __init__(self, echo=False):
284 def __init__(self, echo=False):
285 super(StreamCapturer, self).__init__()
285 super(StreamCapturer, self).__init__()
286 self.echo = echo
286 self.echo = echo
287 self.streams = []
287 self.streams = []
288 self.buffer = BytesIO()
288 self.buffer = BytesIO()
289 self.readfd, self.writefd = os.pipe()
289 self.readfd, self.writefd = os.pipe()
290 self.buffer_lock = Lock()
290 self.buffer_lock = Lock()
291 self.stop = Event()
291 self.stop = Event()
292
292
293 def run(self):
293 def run(self):
294 self.started = True
294 self.started = True
295
295
296 while not self.stop.is_set():
296 while not self.stop.is_set():
297 chunk = os.read(self.readfd, 1024)
297 chunk = os.read(self.readfd, 1024)
298
298
299 with self.buffer_lock:
299 with self.buffer_lock:
300 self.buffer.write(chunk)
300 self.buffer.write(chunk)
301 if self.echo:
301 if self.echo:
302 sys.stdout.write(decode(chunk))
302 sys.stdout.write(decode(chunk))
303
303
304 os.close(self.readfd)
304 os.close(self.readfd)
305 os.close(self.writefd)
305 os.close(self.writefd)
306
306
307 def reset_buffer(self):
307 def reset_buffer(self):
308 with self.buffer_lock:
308 with self.buffer_lock:
309 self.buffer.truncate(0)
309 self.buffer.truncate(0)
310 self.buffer.seek(0)
310 self.buffer.seek(0)
311
311
312 def get_buffer(self):
312 def get_buffer(self):
313 with self.buffer_lock:
313 with self.buffer_lock:
314 return self.buffer.getvalue()
314 return self.buffer.getvalue()
315
315
316 def ensure_started(self):
316 def ensure_started(self):
317 if not self.started:
317 if not self.started:
318 self.start()
318 self.start()
319
319
320 def halt(self):
320 def halt(self):
321 """Safely stop the thread."""
321 """Safely stop the thread."""
322 if not self.started:
322 if not self.started:
323 return
323 return
324
324
325 self.stop.set()
325 self.stop.set()
326 os.write(self.writefd, b'\0') # Ensure we're not locked in a read()
326 os.write(self.writefd, b'\0') # Ensure we're not locked in a read()
327 self.join()
327 self.join()
328
328
329 class SubprocessStreamCapturePlugin(Plugin):
329 class SubprocessStreamCapturePlugin(Plugin):
330 name='subprocstreams'
330 name='subprocstreams'
331 def __init__(self):
331 def __init__(self):
332 Plugin.__init__(self)
332 Plugin.__init__(self)
333 self.stream_capturer = StreamCapturer()
333 self.stream_capturer = StreamCapturer()
334 self.destination = os.environ.get('IPTEST_SUBPROC_STREAMS', 'capture')
334 self.destination = os.environ.get('IPTEST_SUBPROC_STREAMS', 'capture')
335 # This is ugly, but distant parts of the test machinery need to be able
335 # This is ugly, but distant parts of the test machinery need to be able
336 # to redirect streams, so we make the object globally accessible.
336 # to redirect streams, so we make the object globally accessible.
337 nose.iptest_stdstreams_fileno = self.get_write_fileno
337 nose.iptest_stdstreams_fileno = self.get_write_fileno
338
338
339 def get_write_fileno(self):
339 def get_write_fileno(self):
340 if self.destination == 'capture':
340 if self.destination == 'capture':
341 self.stream_capturer.ensure_started()
341 self.stream_capturer.ensure_started()
342 return self.stream_capturer.writefd
342 return self.stream_capturer.writefd
343 elif self.destination == 'discard':
343 elif self.destination == 'discard':
344 return os.open(os.devnull, os.O_WRONLY)
344 return os.open(os.devnull, os.O_WRONLY)
345 else:
345 else:
346 return sys.__stdout__.fileno()
346 return sys.__stdout__.fileno()
347
347
348 def configure(self, options, config):
348 def configure(self, options, config):
349 Plugin.configure(self, options, config)
349 Plugin.configure(self, options, config)
350 # Override nose trying to disable plugin.
350 # Override nose trying to disable plugin.
351 if self.destination == 'capture':
351 if self.destination == 'capture':
352 self.enabled = True
352 self.enabled = True
353
353
354 def startTest(self, test):
354 def startTest(self, test):
355 # Reset log capture
355 # Reset log capture
356 self.stream_capturer.reset_buffer()
356 self.stream_capturer.reset_buffer()
357
357
358 def formatFailure(self, test, err):
358 def formatFailure(self, test, err):
359 # Show output
359 # Show output
360 ec, ev, tb = err
360 ec, ev, tb = err
361 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
361 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
362 if captured.strip():
362 if captured.strip():
363 ev = safe_str(ev)
363 ev = safe_str(ev)
364 out = [ev, '>> begin captured subprocess output <<',
364 out = [ev, '>> begin captured subprocess output <<',
365 captured,
365 captured,
366 '>> end captured subprocess output <<']
366 '>> end captured subprocess output <<']
367 return ec, '\n'.join(out), tb
367 return ec, '\n'.join(out), tb
368
368
369 return err
369 return err
370
370
371 formatError = formatFailure
371 formatError = formatFailure
372
372
373 def finalize(self, result):
373 def finalize(self, result):
374 self.stream_capturer.halt()
374 self.stream_capturer.halt()
375
375
376
376
377 def run_iptest():
377 def run_iptest():
378 """Run the IPython test suite using nose.
378 """Run the IPython test suite using nose.
379
379
380 This function is called when this script is **not** called with the form
380 This function is called when this script is **not** called with the form
381 `iptest all`. It simply calls nose with appropriate command line flags
381 `iptest all`. It simply calls nose with appropriate command line flags
382 and accepts all of the standard nose arguments.
382 and accepts all of the standard nose arguments.
383 """
383 """
384 # Apply our monkeypatch to Xunit
384 # Apply our monkeypatch to Xunit
385 if '--with-xunit' in sys.argv and not hasattr(Xunit, 'orig_addError'):
385 if '--with-xunit' in sys.argv and not hasattr(Xunit, 'orig_addError'):
386 monkeypatch_xunit()
386 monkeypatch_xunit()
387
387
388 arg1 = sys.argv[1]
388 arg1 = sys.argv[1]
389 if arg1.startswith('IPython/'):
390 if arg1.endswith('.py'):
391 arg1 = arg1[:-3]
392 sys.argv[1] = arg1.replace('/', '.')
393
394 arg1 = sys.argv[1]
389 if arg1 in test_sections:
395 if arg1 in test_sections:
390 section = test_sections[arg1]
396 section = test_sections[arg1]
391 sys.argv[1:2] = section.includes
397 sys.argv[1:2] = section.includes
392 elif arg1.startswith('IPython.') and arg1[8:] in test_sections:
398 elif arg1.startswith('IPython.') and arg1[8:] in test_sections:
393 section = test_sections[arg1[8:]]
399 section = test_sections[arg1[8:]]
394 sys.argv[1:2] = section.includes
400 sys.argv[1:2] = section.includes
395 else:
401 else:
396 section = TestSection(arg1, includes=[arg1])
402 section = TestSection(arg1, includes=[arg1])
397
403
398
404
399 argv = sys.argv + [ '--detailed-errors', # extra info in tracebacks
405 argv = sys.argv + [ '--detailed-errors', # extra info in tracebacks
400 # We add --exe because of setuptools' imbecility (it
406 # We add --exe because of setuptools' imbecility (it
401 # blindly does chmod +x on ALL files). Nose does the
407 # blindly does chmod +x on ALL files). Nose does the
402 # right thing and it tries to avoid executables,
408 # right thing and it tries to avoid executables,
403 # setuptools unfortunately forces our hand here. This
409 # setuptools unfortunately forces our hand here. This
404 # has been discussed on the distutils list and the
410 # has been discussed on the distutils list and the
405 # setuptools devs refuse to fix this problem!
411 # setuptools devs refuse to fix this problem!
406 '--exe',
412 '--exe',
407 ]
413 ]
408 if '-a' not in argv and '-A' not in argv:
414 if '-a' not in argv and '-A' not in argv:
409 argv = argv + ['-a', '!crash']
415 argv = argv + ['-a', '!crash']
410
416
411 if nose.__version__ >= '0.11':
417 if nose.__version__ >= '0.11':
412 # I don't fully understand why we need this one, but depending on what
418 # I don't fully understand why we need this one, but depending on what
413 # directory the test suite is run from, if we don't give it, 0 tests
419 # directory the test suite is run from, if we don't give it, 0 tests
414 # get run. Specifically, if the test suite is run from the source dir
420 # get run. Specifically, if the test suite is run from the source dir
415 # with an argument (like 'iptest.py IPython.core', 0 tests are run,
421 # with an argument (like 'iptest.py IPython.core', 0 tests are run,
416 # even if the same call done in this directory works fine). It appears
422 # even if the same call done in this directory works fine). It appears
417 # that if the requested package is in the current dir, nose bails early
423 # that if the requested package is in the current dir, nose bails early
418 # by default. Since it's otherwise harmless, leave it in by default
424 # by default. Since it's otherwise harmless, leave it in by default
419 # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
425 # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
420 argv.append('--traverse-namespace')
426 argv.append('--traverse-namespace')
421
427
422 plugins = [ ExclusionPlugin(section.excludes), KnownFailure(),
428 plugins = [ ExclusionPlugin(section.excludes), KnownFailure(),
423 SubprocessStreamCapturePlugin() ]
429 SubprocessStreamCapturePlugin() ]
424
430
425 # we still have some vestigial doctests in core
431 # we still have some vestigial doctests in core
426 if (section.name.startswith(('core', 'IPython.core', 'IPython.utils'))):
432 if (section.name.startswith(('core', 'IPython.core', 'IPython.utils'))):
427 plugins.append(IPythonDoctest())
433 plugins.append(IPythonDoctest())
428 argv.extend([
434 argv.extend([
429 '--with-ipdoctest',
435 '--with-ipdoctest',
430 '--ipdoctest-tests',
436 '--ipdoctest-tests',
431 '--ipdoctest-extension=txt',
437 '--ipdoctest-extension=txt',
432 ])
438 ])
433
439
434
440
435 # Use working directory set by parent process (see iptestcontroller)
441 # Use working directory set by parent process (see iptestcontroller)
436 if 'IPTEST_WORKING_DIR' in os.environ:
442 if 'IPTEST_WORKING_DIR' in os.environ:
437 os.chdir(os.environ['IPTEST_WORKING_DIR'])
443 os.chdir(os.environ['IPTEST_WORKING_DIR'])
438
444
439 # We need a global ipython running in this process, but the special
445 # We need a global ipython running in this process, but the special
440 # in-process group spawns its own IPython kernels, so for *that* group we
446 # in-process group spawns its own IPython kernels, so for *that* group we
441 # must avoid also opening the global one (otherwise there's a conflict of
447 # must avoid also opening the global one (otherwise there's a conflict of
442 # singletons). Ultimately the solution to this problem is to refactor our
448 # singletons). Ultimately the solution to this problem is to refactor our
443 # assumptions about what needs to be a singleton and what doesn't (app
449 # assumptions about what needs to be a singleton and what doesn't (app
444 # objects should, individual shells shouldn't). But for now, this
450 # objects should, individual shells shouldn't). But for now, this
445 # workaround allows the test suite for the inprocess module to complete.
451 # workaround allows the test suite for the inprocess module to complete.
446 if 'kernel.inprocess' not in section.name:
452 if 'kernel.inprocess' not in section.name:
447 from IPython.testing import globalipapp
453 from IPython.testing import globalipapp
448 globalipapp.start_ipython()
454 globalipapp.start_ipython()
449
455
450 # Now nose can run
456 # Now nose can run
451 TestProgram(argv=argv, addplugins=plugins)
457 TestProgram(argv=argv, addplugins=plugins)
452
458
453 if __name__ == '__main__':
459 if __name__ == '__main__':
454 run_iptest()
460 run_iptest()
General Comments 0
You need to be logged in to leave comments. Login now