##// END OF EJS Templates
Merge pull request #10025 from takluyver/deepreload-no-numpy...
Min RK -
r22972:e4b4f5a1 merge
parent child Browse files
Show More
@@ -1,344 +1,344 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Provides a reload() function that acts recursively.
3 Provides a reload() function that acts recursively.
4
4
5 Python's normal :func:`python:reload` function only reloads the module that it's
5 Python's normal :func:`python:reload` function only reloads the module that it's
6 passed. The :func:`reload` function in this module also reloads everything
6 passed. The :func:`reload` function in this module also reloads everything
7 imported from that module, which is useful when you're changing files deep
7 imported from that module, which is useful when you're changing files deep
8 inside a package.
8 inside a package.
9
9
10 To use this as your default reload function, type this for Python 2::
10 To use this as your default reload function, type this for Python 2::
11
11
12 import __builtin__
12 import __builtin__
13 from IPython.lib import deepreload
13 from IPython.lib import deepreload
14 __builtin__.reload = deepreload.reload
14 __builtin__.reload = deepreload.reload
15
15
16 Or this for Python 3::
16 Or this for Python 3::
17
17
18 import builtins
18 import builtins
19 from IPython.lib import deepreload
19 from IPython.lib import deepreload
20 builtins.reload = deepreload.reload
20 builtins.reload = deepreload.reload
21
21
22 A reference to the original :func:`python:reload` is stored in this module as
22 A reference to the original :func:`python:reload` is stored in this module as
23 :data:`original_reload`, so you can restore it later.
23 :data:`original_reload`, so you can restore it later.
24
24
25 This code is almost entirely based on knee.py, which is a Python
25 This code is almost entirely based on knee.py, which is a Python
26 re-implementation of hierarchical module import.
26 re-implementation of hierarchical module import.
27 """
27 """
28 from __future__ import print_function
28 from __future__ import print_function
29 #*****************************************************************************
29 #*****************************************************************************
30 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
30 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
31 #
31 #
32 # Distributed under the terms of the BSD License. The full license is in
32 # Distributed under the terms of the BSD License. The full license is in
33 # the file COPYING, distributed as part of this software.
33 # the file COPYING, distributed as part of this software.
34 #*****************************************************************************
34 #*****************************************************************************
35
35
36 from contextlib import contextmanager
36 from contextlib import contextmanager
37 import imp
37 import imp
38 import sys
38 import sys
39
39
40 from types import ModuleType
40 from types import ModuleType
41 from warnings import warn
41 from warnings import warn
42
42
43 from IPython.utils.py3compat import builtin_mod, builtin_mod_name
43 from IPython.utils.py3compat import builtin_mod, builtin_mod_name
44
44
45 original_import = builtin_mod.__import__
45 original_import = builtin_mod.__import__
46
46
47 @contextmanager
47 @contextmanager
48 def replace_import_hook(new_import):
48 def replace_import_hook(new_import):
49 saved_import = builtin_mod.__import__
49 saved_import = builtin_mod.__import__
50 builtin_mod.__import__ = new_import
50 builtin_mod.__import__ = new_import
51 try:
51 try:
52 yield
52 yield
53 finally:
53 finally:
54 builtin_mod.__import__ = saved_import
54 builtin_mod.__import__ = saved_import
55
55
56 def get_parent(globals, level):
56 def get_parent(globals, level):
57 """
57 """
58 parent, name = get_parent(globals, level)
58 parent, name = get_parent(globals, level)
59
59
60 Return the package that an import is being performed in. If globals comes
60 Return the package that an import is being performed in. If globals comes
61 from the module foo.bar.bat (not itself a package), this returns the
61 from the module foo.bar.bat (not itself a package), this returns the
62 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
62 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
63 the package's entry in sys.modules is returned.
63 the package's entry in sys.modules is returned.
64
64
65 If globals doesn't come from a package or a module in a package, or a
65 If globals doesn't come from a package or a module in a package, or a
66 corresponding entry is not found in sys.modules, None is returned.
66 corresponding entry is not found in sys.modules, None is returned.
67 """
67 """
68 orig_level = level
68 orig_level = level
69
69
70 if not level or not isinstance(globals, dict):
70 if not level or not isinstance(globals, dict):
71 return None, ''
71 return None, ''
72
72
73 pkgname = globals.get('__package__', None)
73 pkgname = globals.get('__package__', None)
74
74
75 if pkgname is not None:
75 if pkgname is not None:
76 # __package__ is set, so use it
76 # __package__ is set, so use it
77 if not hasattr(pkgname, 'rindex'):
77 if not hasattr(pkgname, 'rindex'):
78 raise ValueError('__package__ set to non-string')
78 raise ValueError('__package__ set to non-string')
79 if len(pkgname) == 0:
79 if len(pkgname) == 0:
80 if level > 0:
80 if level > 0:
81 raise ValueError('Attempted relative import in non-package')
81 raise ValueError('Attempted relative import in non-package')
82 return None, ''
82 return None, ''
83 name = pkgname
83 name = pkgname
84 else:
84 else:
85 # __package__ not set, so figure it out and set it
85 # __package__ not set, so figure it out and set it
86 if '__name__' not in globals:
86 if '__name__' not in globals:
87 return None, ''
87 return None, ''
88 modname = globals['__name__']
88 modname = globals['__name__']
89
89
90 if '__path__' in globals:
90 if '__path__' in globals:
91 # __path__ is set, so modname is already the package name
91 # __path__ is set, so modname is already the package name
92 globals['__package__'] = name = modname
92 globals['__package__'] = name = modname
93 else:
93 else:
94 # Normal module, so work out the package name if any
94 # Normal module, so work out the package name if any
95 lastdot = modname.rfind('.')
95 lastdot = modname.rfind('.')
96 if lastdot < 0 < level:
96 if lastdot < 0 < level:
97 raise ValueError("Attempted relative import in non-package")
97 raise ValueError("Attempted relative import in non-package")
98 if lastdot < 0:
98 if lastdot < 0:
99 globals['__package__'] = None
99 globals['__package__'] = None
100 return None, ''
100 return None, ''
101 globals['__package__'] = name = modname[:lastdot]
101 globals['__package__'] = name = modname[:lastdot]
102
102
103 dot = len(name)
103 dot = len(name)
104 for x in range(level, 1, -1):
104 for x in range(level, 1, -1):
105 try:
105 try:
106 dot = name.rindex('.', 0, dot)
106 dot = name.rindex('.', 0, dot)
107 except ValueError:
107 except ValueError:
108 raise ValueError("attempted relative import beyond top-level "
108 raise ValueError("attempted relative import beyond top-level "
109 "package")
109 "package")
110 name = name[:dot]
110 name = name[:dot]
111
111
112 try:
112 try:
113 parent = sys.modules[name]
113 parent = sys.modules[name]
114 except:
114 except:
115 if orig_level < 1:
115 if orig_level < 1:
116 warn("Parent module '%.200s' not found while handling absolute "
116 warn("Parent module '%.200s' not found while handling absolute "
117 "import" % name)
117 "import" % name)
118 parent = None
118 parent = None
119 else:
119 else:
120 raise SystemError("Parent module '%.200s' not loaded, cannot "
120 raise SystemError("Parent module '%.200s' not loaded, cannot "
121 "perform relative import" % name)
121 "perform relative import" % name)
122
122
123 # We expect, but can't guarantee, if parent != None, that:
123 # We expect, but can't guarantee, if parent != None, that:
124 # - parent.__name__ == name
124 # - parent.__name__ == name
125 # - parent.__dict__ is globals
125 # - parent.__dict__ is globals
126 # If this is violated... Who cares?
126 # If this is violated... Who cares?
127 return parent, name
127 return parent, name
128
128
129 def load_next(mod, altmod, name, buf):
129 def load_next(mod, altmod, name, buf):
130 """
130 """
131 mod, name, buf = load_next(mod, altmod, name, buf)
131 mod, name, buf = load_next(mod, altmod, name, buf)
132
132
133 altmod is either None or same as mod
133 altmod is either None or same as mod
134 """
134 """
135
135
136 if len(name) == 0:
136 if len(name) == 0:
137 # completely empty module name should only happen in
137 # completely empty module name should only happen in
138 # 'from . import' (or '__import__("")')
138 # 'from . import' (or '__import__("")')
139 return mod, None, buf
139 return mod, None, buf
140
140
141 dot = name.find('.')
141 dot = name.find('.')
142 if dot == 0:
142 if dot == 0:
143 raise ValueError('Empty module name')
143 raise ValueError('Empty module name')
144
144
145 if dot < 0:
145 if dot < 0:
146 subname = name
146 subname = name
147 next = None
147 next = None
148 else:
148 else:
149 subname = name[:dot]
149 subname = name[:dot]
150 next = name[dot+1:]
150 next = name[dot+1:]
151
151
152 if buf != '':
152 if buf != '':
153 buf += '.'
153 buf += '.'
154 buf += subname
154 buf += subname
155
155
156 result = import_submodule(mod, subname, buf)
156 result = import_submodule(mod, subname, buf)
157 if result is None and mod != altmod:
157 if result is None and mod != altmod:
158 result = import_submodule(altmod, subname, subname)
158 result = import_submodule(altmod, subname, subname)
159 if result is not None:
159 if result is not None:
160 buf = subname
160 buf = subname
161
161
162 if result is None:
162 if result is None:
163 raise ImportError("No module named %.200s" % name)
163 raise ImportError("No module named %.200s" % name)
164
164
165 return result, next, buf
165 return result, next, buf
166
166
167 # Need to keep track of what we've already reloaded to prevent cyclic evil
167 # Need to keep track of what we've already reloaded to prevent cyclic evil
168 found_now = {}
168 found_now = {}
169
169
170 def import_submodule(mod, subname, fullname):
170 def import_submodule(mod, subname, fullname):
171 """m = import_submodule(mod, subname, fullname)"""
171 """m = import_submodule(mod, subname, fullname)"""
172 # Require:
172 # Require:
173 # if mod == None: subname == fullname
173 # if mod == None: subname == fullname
174 # else: mod.__name__ + "." + subname == fullname
174 # else: mod.__name__ + "." + subname == fullname
175
175
176 global found_now
176 global found_now
177 if fullname in found_now and fullname in sys.modules:
177 if fullname in found_now and fullname in sys.modules:
178 m = sys.modules[fullname]
178 m = sys.modules[fullname]
179 else:
179 else:
180 print('Reloading', fullname)
180 print('Reloading', fullname)
181 found_now[fullname] = 1
181 found_now[fullname] = 1
182 oldm = sys.modules.get(fullname, None)
182 oldm = sys.modules.get(fullname, None)
183
183
184 if mod is None:
184 if mod is None:
185 path = None
185 path = None
186 elif hasattr(mod, '__path__'):
186 elif hasattr(mod, '__path__'):
187 path = mod.__path__
187 path = mod.__path__
188 else:
188 else:
189 return None
189 return None
190
190
191 try:
191 try:
192 # This appears to be necessary on Python 3, because imp.find_module()
192 # This appears to be necessary on Python 3, because imp.find_module()
193 # tries to import standard libraries (like io) itself, and we don't
193 # tries to import standard libraries (like io) itself, and we don't
194 # want them to be processed by our deep_import_hook.
194 # want them to be processed by our deep_import_hook.
195 with replace_import_hook(original_import):
195 with replace_import_hook(original_import):
196 fp, filename, stuff = imp.find_module(subname, path)
196 fp, filename, stuff = imp.find_module(subname, path)
197 except ImportError:
197 except ImportError:
198 return None
198 return None
199
199
200 try:
200 try:
201 m = imp.load_module(fullname, fp, filename, stuff)
201 m = imp.load_module(fullname, fp, filename, stuff)
202 except:
202 except:
203 # load_module probably removed name from modules because of
203 # load_module probably removed name from modules because of
204 # the error. Put back the original module object.
204 # the error. Put back the original module object.
205 if oldm:
205 if oldm:
206 sys.modules[fullname] = oldm
206 sys.modules[fullname] = oldm
207 raise
207 raise
208 finally:
208 finally:
209 if fp: fp.close()
209 if fp: fp.close()
210
210
211 add_submodule(mod, m, fullname, subname)
211 add_submodule(mod, m, fullname, subname)
212
212
213 return m
213 return m
214
214
215 def add_submodule(mod, submod, fullname, subname):
215 def add_submodule(mod, submod, fullname, subname):
216 """mod.{subname} = submod"""
216 """mod.{subname} = submod"""
217 if mod is None:
217 if mod is None:
218 return #Nothing to do here.
218 return #Nothing to do here.
219
219
220 if submod is None:
220 if submod is None:
221 submod = sys.modules[fullname]
221 submod = sys.modules[fullname]
222
222
223 setattr(mod, subname, submod)
223 setattr(mod, subname, submod)
224
224
225 return
225 return
226
226
227 def ensure_fromlist(mod, fromlist, buf, recursive):
227 def ensure_fromlist(mod, fromlist, buf, recursive):
228 """Handle 'from module import a, b, c' imports."""
228 """Handle 'from module import a, b, c' imports."""
229 if not hasattr(mod, '__path__'):
229 if not hasattr(mod, '__path__'):
230 return
230 return
231 for item in fromlist:
231 for item in fromlist:
232 if not hasattr(item, 'rindex'):
232 if not hasattr(item, 'rindex'):
233 raise TypeError("Item in ``from list'' not a string")
233 raise TypeError("Item in ``from list'' not a string")
234 if item == '*':
234 if item == '*':
235 if recursive:
235 if recursive:
236 continue # avoid endless recursion
236 continue # avoid endless recursion
237 try:
237 try:
238 all = mod.__all__
238 all = mod.__all__
239 except AttributeError:
239 except AttributeError:
240 pass
240 pass
241 else:
241 else:
242 ret = ensure_fromlist(mod, all, buf, 1)
242 ret = ensure_fromlist(mod, all, buf, 1)
243 if not ret:
243 if not ret:
244 return 0
244 return 0
245 elif not hasattr(mod, item):
245 elif not hasattr(mod, item):
246 import_submodule(mod, item, buf + '.' + item)
246 import_submodule(mod, item, buf + '.' + item)
247
247
248 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
248 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
249 """Replacement for __import__()"""
249 """Replacement for __import__()"""
250 parent, buf = get_parent(globals, level)
250 parent, buf = get_parent(globals, level)
251
251
252 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
252 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
253
253
254 tail = head
254 tail = head
255 while name:
255 while name:
256 tail, name, buf = load_next(tail, tail, name, buf)
256 tail, name, buf = load_next(tail, tail, name, buf)
257
257
258 # If tail is None, both get_parent and load_next found
258 # If tail is None, both get_parent and load_next found
259 # an empty module name: someone called __import__("") or
259 # an empty module name: someone called __import__("") or
260 # doctored faulty bytecode
260 # doctored faulty bytecode
261 if tail is None:
261 if tail is None:
262 raise ValueError('Empty module name')
262 raise ValueError('Empty module name')
263
263
264 if not fromlist:
264 if not fromlist:
265 return head
265 return head
266
266
267 ensure_fromlist(tail, fromlist, buf, 0)
267 ensure_fromlist(tail, fromlist, buf, 0)
268 return tail
268 return tail
269
269
270 modules_reloading = {}
270 modules_reloading = {}
271
271
272 def deep_reload_hook(m):
272 def deep_reload_hook(m):
273 """Replacement for reload()."""
273 """Replacement for reload()."""
274 if not isinstance(m, ModuleType):
274 if not isinstance(m, ModuleType):
275 raise TypeError("reload() argument must be module")
275 raise TypeError("reload() argument must be module")
276
276
277 name = m.__name__
277 name = m.__name__
278
278
279 if name not in sys.modules:
279 if name not in sys.modules:
280 raise ImportError("reload(): module %.200s not in sys.modules" % name)
280 raise ImportError("reload(): module %.200s not in sys.modules" % name)
281
281
282 global modules_reloading
282 global modules_reloading
283 try:
283 try:
284 return modules_reloading[name]
284 return modules_reloading[name]
285 except:
285 except:
286 modules_reloading[name] = m
286 modules_reloading[name] = m
287
287
288 dot = name.rfind('.')
288 dot = name.rfind('.')
289 if dot < 0:
289 if dot < 0:
290 subname = name
290 subname = name
291 path = None
291 path = None
292 else:
292 else:
293 try:
293 try:
294 parent = sys.modules[name[:dot]]
294 parent = sys.modules[name[:dot]]
295 except KeyError:
295 except KeyError:
296 modules_reloading.clear()
296 modules_reloading.clear()
297 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
297 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
298 subname = name[dot+1:]
298 subname = name[dot+1:]
299 path = getattr(parent, "__path__", None)
299 path = getattr(parent, "__path__", None)
300
300
301 try:
301 try:
302 # This appears to be necessary on Python 3, because imp.find_module()
302 # This appears to be necessary on Python 3, because imp.find_module()
303 # tries to import standard libraries (like io) itself, and we don't
303 # tries to import standard libraries (like io) itself, and we don't
304 # want them to be processed by our deep_import_hook.
304 # want them to be processed by our deep_import_hook.
305 with replace_import_hook(original_import):
305 with replace_import_hook(original_import):
306 fp, filename, stuff = imp.find_module(subname, path)
306 fp, filename, stuff = imp.find_module(subname, path)
307 finally:
307 finally:
308 modules_reloading.clear()
308 modules_reloading.clear()
309
309
310 try:
310 try:
311 newm = imp.load_module(name, fp, filename, stuff)
311 newm = imp.load_module(name, fp, filename, stuff)
312 except:
312 except:
313 # load_module probably removed name from modules because of
313 # load_module probably removed name from modules because of
314 # the error. Put back the original module object.
314 # the error. Put back the original module object.
315 sys.modules[name] = m
315 sys.modules[name] = m
316 raise
316 raise
317 finally:
317 finally:
318 if fp: fp.close()
318 if fp: fp.close()
319
319
320 modules_reloading.clear()
320 modules_reloading.clear()
321 return newm
321 return newm
322
322
323 # Save the original hooks
323 # Save the original hooks
324 try:
324 try:
325 original_reload = builtin_mod.reload
325 original_reload = builtin_mod.reload
326 except AttributeError:
326 except AttributeError:
327 original_reload = imp.reload # Python 3
327 original_reload = imp.reload # Python 3
328
328
329 # Replacement for reload()
329 # Replacement for reload()
330 def reload(module, exclude=('sys', 'os.path', builtin_mod_name, '__main__',
330 def reload(module, exclude=('sys', 'os.path', builtin_mod_name, '__main__',
331 'numpy._globals')):
331 'numpy', 'numpy._globals')):
332 """Recursively reload all modules used in the given module. Optionally
332 """Recursively reload all modules used in the given module. Optionally
333 takes a list of modules to exclude from reloading. The default exclude
333 takes a list of modules to exclude from reloading. The default exclude
334 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
334 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
335 display, exception, and io hooks.
335 display, exception, and io hooks.
336 """
336 """
337 global found_now
337 global found_now
338 for i in exclude:
338 for i in exclude:
339 found_now[i] = 1
339 found_now[i] = 1
340 try:
340 try:
341 with replace_import_hook(deep_import_hook):
341 with replace_import_hook(deep_import_hook):
342 return deep_reload_hook(module)
342 return deep_reload_hook(module)
343 finally:
343 finally:
344 found_now = {}
344 found_now = {}
@@ -1,59 +1,34 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Test suite for the deepreload module."""
2 """Test suite for the deepreload module."""
3
3
4 #-----------------------------------------------------------------------------
4 # Copyright (c) IPython Development Team.
5 # Imports
5 # Distributed under the terms of the Modified BSD License.
6 #-----------------------------------------------------------------------------
7
6
8 import os
7 import os
9
8
10 import nose.tools as nt
9 import nose.tools as nt
11
10
12 from IPython.testing import decorators as dec
13 from IPython.utils.py3compat import builtin_mod_name, PY3
14 from IPython.utils.syspathcontext import prepended_to_syspath
11 from IPython.utils.syspathcontext import prepended_to_syspath
15 from IPython.utils.tempdir import TemporaryDirectory
12 from IPython.utils.tempdir import TemporaryDirectory
16 from IPython.lib.deepreload import reload as dreload
13 from IPython.lib.deepreload import reload as dreload
17
14
18 #-----------------------------------------------------------------------------
19 # Test functions begin
20 #-----------------------------------------------------------------------------
21
22 @dec.skipif_not_numpy
23 def test_deepreload_numpy():
24 "Test that NumPy can be deep reloaded."
25 import numpy
26 # TODO: Find a way to exclude all standard library modules from reloading.
27 exclude = [
28 # Standard exclusions:
29 'sys', 'os.path', builtin_mod_name, '__main__',
30 # Test-related exclusions:
31 'unittest', 'UserDict', '_collections_abc', 'tokenize',
32 'collections', 'collections.abc',
33 'importlib', 'importlib.machinery', '_imp',
34 'importlib._bootstrap', 'importlib._bootstrap_external',
35 '_frozen_importlib', '_frozen_importlib_external',
36 ]
37
38 dreload(numpy, exclude=exclude)
39
40 def test_deepreload():
15 def test_deepreload():
41 "Test that dreload does deep reloads and skips excluded modules."
16 "Test that dreload does deep reloads and skips excluded modules."
42 with TemporaryDirectory() as tmpdir:
17 with TemporaryDirectory() as tmpdir:
43 with prepended_to_syspath(tmpdir):
18 with prepended_to_syspath(tmpdir):
44 with open(os.path.join(tmpdir, 'A.py'), 'w') as f:
19 with open(os.path.join(tmpdir, 'A.py'), 'w') as f:
45 f.write("class Object(object):\n pass\n")
20 f.write("class Object(object):\n pass\n")
46 with open(os.path.join(tmpdir, 'B.py'), 'w') as f:
21 with open(os.path.join(tmpdir, 'B.py'), 'w') as f:
47 f.write("import A\n")
22 f.write("import A\n")
48 import A
23 import A
49 import B
24 import B
50
25
51 # Test that A is not reloaded.
26 # Test that A is not reloaded.
52 obj = A.Object()
27 obj = A.Object()
53 dreload(B, exclude=['A'])
28 dreload(B, exclude=['A'])
54 nt.assert_true(isinstance(obj, A.Object))
29 nt.assert_true(isinstance(obj, A.Object))
55
30
56 # Test that A is reloaded.
31 # Test that A is reloaded.
57 obj = A.Object()
32 obj = A.Object()
58 dreload(B)
33 dreload(B)
59 nt.assert_false(isinstance(obj, A.Object))
34 nt.assert_false(isinstance(obj, A.Object))
General Comments 0
You need to be logged in to leave comments. Login now