##// END OF EJS Templates
Port `deepreload` to `importlib`...
Nikita Kniazev -
Show More
@@ -1,341 +1,299
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Provides a reload() function that acts recursively.
3 Provides a reload() function that acts recursively.
4
4
5 Python's normal :func:`python:reload` function only reloads the module that it's
5 Python's normal :func:`python:reload` function only reloads the module that it's
6 passed. The :func:`reload` function in this module also reloads everything
6 passed. The :func:`reload` function in this module also reloads everything
7 imported from that module, which is useful when you're changing files deep
7 imported from that module, which is useful when you're changing files deep
8 inside a package.
8 inside a package.
9
9
10 To use this as your default reload function, type this::
10 To use this as your default reload function, type this::
11
11
12 import builtins
12 import builtins
13 from IPython.lib import deepreload
13 from IPython.lib import deepreload
14 builtins.reload = deepreload.reload
14 builtins.reload = deepreload.reload
15
15
16 A reference to the original :func:`python:reload` is stored in this module as
16 A reference to the original :func:`python:reload` is stored in this module as
17 :data:`original_reload`, so you can restore it later.
17 :data:`original_reload`, so you can restore it later.
18
18
19 This code is almost entirely based on knee.py, which is a Python
19 This code is almost entirely based on knee.py, which is a Python
20 re-implementation of hierarchical module import.
20 re-implementation of hierarchical module import.
21 """
21 """
22 #*****************************************************************************
22 #*****************************************************************************
23 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
23 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
24 #
24 #
25 # Distributed under the terms of the BSD License. The full license is in
25 # Distributed under the terms of the BSD License. The full license is in
26 # the file COPYING, distributed as part of this software.
26 # the file COPYING, distributed as part of this software.
27 #*****************************************************************************
27 #*****************************************************************************
28
28
29 import builtins as builtin_mod
29 import builtins as builtin_mod
30 from contextlib import contextmanager
30 from contextlib import contextmanager
31 import imp
31 import importlib
32 import sys
32 import sys
33
33
34 from types import ModuleType
34 from types import ModuleType
35 from warnings import warn
35 from warnings import warn
36 import types
36 import types
37
37
38 original_import = builtin_mod.__import__
38 original_import = builtin_mod.__import__
39
39
40 @contextmanager
40 @contextmanager
41 def replace_import_hook(new_import):
41 def replace_import_hook(new_import):
42 saved_import = builtin_mod.__import__
42 saved_import = builtin_mod.__import__
43 builtin_mod.__import__ = new_import
43 builtin_mod.__import__ = new_import
44 try:
44 try:
45 yield
45 yield
46 finally:
46 finally:
47 builtin_mod.__import__ = saved_import
47 builtin_mod.__import__ = saved_import
48
48
49 def get_parent(globals, level):
49 def get_parent(globals, level):
50 """
50 """
51 parent, name = get_parent(globals, level)
51 parent, name = get_parent(globals, level)
52
52
53 Return the package that an import is being performed in. If globals comes
53 Return the package that an import is being performed in. If globals comes
54 from the module foo.bar.bat (not itself a package), this returns the
54 from the module foo.bar.bat (not itself a package), this returns the
55 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
55 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
56 the package's entry in sys.modules is returned.
56 the package's entry in sys.modules is returned.
57
57
58 If globals doesn't come from a package or a module in a package, or a
58 If globals doesn't come from a package or a module in a package, or a
59 corresponding entry is not found in sys.modules, None is returned.
59 corresponding entry is not found in sys.modules, None is returned.
60 """
60 """
61 orig_level = level
61 orig_level = level
62
62
63 if not level or not isinstance(globals, dict):
63 if not level or not isinstance(globals, dict):
64 return None, ''
64 return None, ''
65
65
66 pkgname = globals.get('__package__', None)
66 pkgname = globals.get('__package__', None)
67
67
68 if pkgname is not None:
68 if pkgname is not None:
69 # __package__ is set, so use it
69 # __package__ is set, so use it
70 if not hasattr(pkgname, 'rindex'):
70 if not hasattr(pkgname, 'rindex'):
71 raise ValueError('__package__ set to non-string')
71 raise ValueError('__package__ set to non-string')
72 if len(pkgname) == 0:
72 if len(pkgname) == 0:
73 if level > 0:
73 if level > 0:
74 raise ValueError('Attempted relative import in non-package')
74 raise ValueError('Attempted relative import in non-package')
75 return None, ''
75 return None, ''
76 name = pkgname
76 name = pkgname
77 else:
77 else:
78 # __package__ not set, so figure it out and set it
78 # __package__ not set, so figure it out and set it
79 if '__name__' not in globals:
79 if '__name__' not in globals:
80 return None, ''
80 return None, ''
81 modname = globals['__name__']
81 modname = globals['__name__']
82
82
83 if '__path__' in globals:
83 if '__path__' in globals:
84 # __path__ is set, so modname is already the package name
84 # __path__ is set, so modname is already the package name
85 globals['__package__'] = name = modname
85 globals['__package__'] = name = modname
86 else:
86 else:
87 # Normal module, so work out the package name if any
87 # Normal module, so work out the package name if any
88 lastdot = modname.rfind('.')
88 lastdot = modname.rfind('.')
89 if lastdot < 0 < level:
89 if lastdot < 0 < level:
90 raise ValueError("Attempted relative import in non-package")
90 raise ValueError("Attempted relative import in non-package")
91 if lastdot < 0:
91 if lastdot < 0:
92 globals['__package__'] = None
92 globals['__package__'] = None
93 return None, ''
93 return None, ''
94 globals['__package__'] = name = modname[:lastdot]
94 globals['__package__'] = name = modname[:lastdot]
95
95
96 dot = len(name)
96 dot = len(name)
97 for x in range(level, 1, -1):
97 for x in range(level, 1, -1):
98 try:
98 try:
99 dot = name.rindex('.', 0, dot)
99 dot = name.rindex('.', 0, dot)
100 except ValueError as e:
100 except ValueError as e:
101 raise ValueError("attempted relative import beyond top-level "
101 raise ValueError("attempted relative import beyond top-level "
102 "package") from e
102 "package") from e
103 name = name[:dot]
103 name = name[:dot]
104
104
105 try:
105 try:
106 parent = sys.modules[name]
106 parent = sys.modules[name]
107 except BaseException as e:
107 except BaseException as e:
108 if orig_level < 1:
108 if orig_level < 1:
109 warn("Parent module '%.200s' not found while handling absolute "
109 warn("Parent module '%.200s' not found while handling absolute "
110 "import" % name)
110 "import" % name)
111 parent = None
111 parent = None
112 else:
112 else:
113 raise SystemError("Parent module '%.200s' not loaded, cannot "
113 raise SystemError("Parent module '%.200s' not loaded, cannot "
114 "perform relative import" % name) from e
114 "perform relative import" % name) from e
115
115
116 # We expect, but can't guarantee, if parent != None, that:
116 # We expect, but can't guarantee, if parent != None, that:
117 # - parent.__name__ == name
117 # - parent.__name__ == name
118 # - parent.__dict__ is globals
118 # - parent.__dict__ is globals
119 # If this is violated... Who cares?
119 # If this is violated... Who cares?
120 return parent, name
120 return parent, name
121
121
122 def load_next(mod, altmod, name, buf):
122 def load_next(mod, altmod, name, buf):
123 """
123 """
124 mod, name, buf = load_next(mod, altmod, name, buf)
124 mod, name, buf = load_next(mod, altmod, name, buf)
125
125
126 altmod is either None or same as mod
126 altmod is either None or same as mod
127 """
127 """
128
128
129 if len(name) == 0:
129 if len(name) == 0:
130 # completely empty module name should only happen in
130 # completely empty module name should only happen in
131 # 'from . import' (or '__import__("")')
131 # 'from . import' (or '__import__("")')
132 return mod, None, buf
132 return mod, None, buf
133
133
134 dot = name.find('.')
134 dot = name.find('.')
135 if dot == 0:
135 if dot == 0:
136 raise ValueError('Empty module name')
136 raise ValueError('Empty module name')
137
137
138 if dot < 0:
138 if dot < 0:
139 subname = name
139 subname = name
140 next = None
140 next = None
141 else:
141 else:
142 subname = name[:dot]
142 subname = name[:dot]
143 next = name[dot+1:]
143 next = name[dot+1:]
144
144
145 if buf != '':
145 if buf != '':
146 buf += '.'
146 buf += '.'
147 buf += subname
147 buf += subname
148
148
149 result = import_submodule(mod, subname, buf)
149 result = import_submodule(mod, subname, buf)
150 if result is None and mod != altmod:
150 if result is None and mod != altmod:
151 result = import_submodule(altmod, subname, subname)
151 result = import_submodule(altmod, subname, subname)
152 if result is not None:
152 if result is not None:
153 buf = subname
153 buf = subname
154
154
155 if result is None:
155 if result is None:
156 raise ImportError("No module named %.200s" % name)
156 raise ImportError("No module named %.200s" % name)
157
157
158 return result, next, buf
158 return result, next, buf
159
159
160
160
161 # Need to keep track of what we've already reloaded to prevent cyclic evil
161 # Need to keep track of what we've already reloaded to prevent cyclic evil
162 found_now = {}
162 found_now = {}
163
163
164 def import_submodule(mod, subname, fullname):
164 def import_submodule(mod, subname, fullname):
165 """m = import_submodule(mod, subname, fullname)"""
165 """m = import_submodule(mod, subname, fullname)"""
166 # Require:
166 # Require:
167 # if mod == None: subname == fullname
167 # if mod == None: subname == fullname
168 # else: mod.__name__ + "." + subname == fullname
168 # else: mod.__name__ + "." + subname == fullname
169
169
170 global found_now
170 global found_now
171 if fullname in found_now and fullname in sys.modules:
171 if fullname in found_now and fullname in sys.modules:
172 m = sys.modules[fullname]
172 m = sys.modules[fullname]
173 else:
173 else:
174 print('Reloading', fullname)
174 print('Reloading', fullname)
175 found_now[fullname] = 1
175 found_now[fullname] = 1
176 oldm = sys.modules.get(fullname, None)
176 oldm = sys.modules.get(fullname, None)
177
178 if mod is None:
179 path = None
180 elif hasattr(mod, '__path__'):
181 path = mod.__path__
182 else:
183 return None
184
185 try:
186 # This appears to be necessary on Python 3, because imp.find_module()
187 # tries to import standard libraries (like io) itself, and we don't
188 # want them to be processed by our deep_import_hook.
189 with replace_import_hook(original_import):
190 fp, filename, stuff = imp.find_module(subname, path)
191 except ImportError:
192 return None
193
194 try:
177 try:
195 m = imp.load_module(fullname, fp, filename, stuff)
178 if oldm is not None:
179 m = importlib.reload(oldm)
180 else:
181 m = importlib.import_module(subname, mod)
196 except:
182 except:
197 # load_module probably removed name from modules because of
183 # load_module probably removed name from modules because of
198 # the error. Put back the original module object.
184 # the error. Put back the original module object.
199 if oldm:
185 if oldm:
200 sys.modules[fullname] = oldm
186 sys.modules[fullname] = oldm
201 raise
187 raise
202 finally:
203 if fp: fp.close()
204
188
205 add_submodule(mod, m, fullname, subname)
189 add_submodule(mod, m, fullname, subname)
206
190
207 return m
191 return m
208
192
209 def add_submodule(mod, submod, fullname, subname):
193 def add_submodule(mod, submod, fullname, subname):
210 """mod.{subname} = submod"""
194 """mod.{subname} = submod"""
211 if mod is None:
195 if mod is None:
212 return #Nothing to do here.
196 return #Nothing to do here.
213
197
214 if submod is None:
198 if submod is None:
215 submod = sys.modules[fullname]
199 submod = sys.modules[fullname]
216
200
217 setattr(mod, subname, submod)
201 setattr(mod, subname, submod)
218
202
219 return
203 return
220
204
221 def ensure_fromlist(mod, fromlist, buf, recursive):
205 def ensure_fromlist(mod, fromlist, buf, recursive):
222 """Handle 'from module import a, b, c' imports."""
206 """Handle 'from module import a, b, c' imports."""
223 if not hasattr(mod, '__path__'):
207 if not hasattr(mod, '__path__'):
224 return
208 return
225 for item in fromlist:
209 for item in fromlist:
226 if not hasattr(item, 'rindex'):
210 if not hasattr(item, 'rindex'):
227 raise TypeError("Item in ``from list'' not a string")
211 raise TypeError("Item in ``from list'' not a string")
228 if item == '*':
212 if item == '*':
229 if recursive:
213 if recursive:
230 continue # avoid endless recursion
214 continue # avoid endless recursion
231 try:
215 try:
232 all = mod.__all__
216 all = mod.__all__
233 except AttributeError:
217 except AttributeError:
234 pass
218 pass
235 else:
219 else:
236 ret = ensure_fromlist(mod, all, buf, 1)
220 ret = ensure_fromlist(mod, all, buf, 1)
237 if not ret:
221 if not ret:
238 return 0
222 return 0
239 elif not hasattr(mod, item):
223 elif not hasattr(mod, item):
240 import_submodule(mod, item, buf + '.' + item)
224 import_submodule(mod, item, buf + '.' + item)
241
225
242 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
226 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
243 """Replacement for __import__()"""
227 """Replacement for __import__()"""
244 parent, buf = get_parent(globals, level)
228 parent, buf = get_parent(globals, level)
245
229
246 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
230 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
247
231
248 tail = head
232 tail = head
249 while name:
233 while name:
250 tail, name, buf = load_next(tail, tail, name, buf)
234 tail, name, buf = load_next(tail, tail, name, buf)
251
235
252 # If tail is None, both get_parent and load_next found
236 # If tail is None, both get_parent and load_next found
253 # an empty module name: someone called __import__("") or
237 # an empty module name: someone called __import__("") or
254 # doctored faulty bytecode
238 # doctored faulty bytecode
255 if tail is None:
239 if tail is None:
256 raise ValueError('Empty module name')
240 raise ValueError('Empty module name')
257
241
258 if not fromlist:
242 if not fromlist:
259 return head
243 return head
260
244
261 ensure_fromlist(tail, fromlist, buf, 0)
245 ensure_fromlist(tail, fromlist, buf, 0)
262 return tail
246 return tail
263
247
264 modules_reloading = {}
248 modules_reloading = {}
265
249
266 def deep_reload_hook(m):
250 def deep_reload_hook(m):
267 """Replacement for reload()."""
251 """Replacement for reload()."""
268 # Hardcode this one as it would raise a NotImplementedError from the
252 # Hardcode this one as it would raise a NotImplementedError from the
269 # bowels of Python and screw up the import machinery after.
253 # bowels of Python and screw up the import machinery after.
270 # unlike other imports the `exclude` list already in place is not enough.
254 # unlike other imports the `exclude` list already in place is not enough.
271
255
272 if m is types:
256 if m is types:
273 return m
257 return m
274 if not isinstance(m, ModuleType):
258 if not isinstance(m, ModuleType):
275 raise TypeError("reload() argument must be module")
259 raise TypeError("reload() argument must be module")
276
260
277 name = m.__name__
261 name = m.__name__
278
262
279 if name not in sys.modules:
263 if name not in sys.modules:
280 raise ImportError("reload(): module %.200s not in sys.modules" % name)
264 raise ImportError("reload(): module %.200s not in sys.modules" % name)
281
265
282 global modules_reloading
266 global modules_reloading
283 try:
267 try:
284 return modules_reloading[name]
268 return modules_reloading[name]
285 except:
269 except:
286 modules_reloading[name] = m
270 modules_reloading[name] = m
287
271
288 dot = name.rfind('.')
289 if dot < 0:
290 subname = name
291 path = None
292 else:
293 try:
272 try:
294 parent = sys.modules[name[:dot]]
273 newm = importlib.reload(m)
295 except KeyError as e:
296 modules_reloading.clear()
297 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot]) from e
298 subname = name[dot+1:]
299 path = getattr(parent, "__path__", None)
300
301 try:
302 # This appears to be necessary on Python 3, because imp.find_module()
303 # tries to import standard libraries (like io) itself, and we don't
304 # want them to be processed by our deep_import_hook.
305 with replace_import_hook(original_import):
306 fp, filename, stuff = imp.find_module(subname, path)
307 finally:
308 modules_reloading.clear()
309
310 try:
311 newm = imp.load_module(name, fp, filename, stuff)
312 except:
274 except:
313 # load_module probably removed name from modules because of
314 # the error. Put back the original module object.
315 sys.modules[name] = m
275 sys.modules[name] = m
316 raise
276 raise
317 finally:
277 finally:
318 if fp: fp.close()
319
320 modules_reloading.clear()
278 modules_reloading.clear()
321 return newm
279 return newm
322
280
323 # Save the original hooks
281 # Save the original hooks
324 original_reload = imp.reload
282 original_reload = importlib.reload
325
283
326 # Replacement for reload()
284 # Replacement for reload()
327 def reload(module, exclude=('sys', 'os.path', 'builtins', '__main__',
285 def reload(module, exclude=('sys', 'os.path', 'builtins', '__main__',
328 'numpy', 'numpy._globals')):
286 'numpy', 'numpy._globals')):
329 """Recursively reload all modules used in the given module. Optionally
287 """Recursively reload all modules used in the given module. Optionally
330 takes a list of modules to exclude from reloading. The default exclude
288 takes a list of modules to exclude from reloading. The default exclude
331 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
289 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
332 display, exception, and io hooks.
290 display, exception, and io hooks.
333 """
291 """
334 global found_now
292 global found_now
335 for i in exclude:
293 for i in exclude:
336 found_now[i] = 1
294 found_now[i] = 1
337 try:
295 try:
338 with replace_import_hook(deep_import_hook):
296 with replace_import_hook(deep_import_hook):
339 return deep_reload_hook(module)
297 return deep_reload_hook(module)
340 finally:
298 finally:
341 found_now = {}
299 found_now = {}
@@ -1,34 +1,56
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Test suite for the deepreload module."""
2 """Test suite for the deepreload module."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import pytest
8 import types
9
7 from pathlib import Path
10 from pathlib import Path
8
11
9 from IPython.utils.syspathcontext import prepended_to_syspath
12 from IPython.utils.syspathcontext import prepended_to_syspath
10 from IPython.utils.tempdir import TemporaryDirectory
13 from IPython.utils.tempdir import TemporaryDirectory
11 from IPython.lib.deepreload import reload as dreload
14 from IPython.lib.deepreload import reload as dreload, modules_reloading
12
15
13
16
14 def test_deepreload():
17 def test_deepreload():
15 "Test that dreload does deep reloads and skips excluded modules."
18 "Test that dreload does deep reloads and skips excluded modules."
16 with TemporaryDirectory() as tmpdir:
19 with TemporaryDirectory() as tmpdir:
17 with prepended_to_syspath(tmpdir):
20 with prepended_to_syspath(tmpdir):
18 tmpdirpath = Path(tmpdir)
21 tmpdirpath = Path(tmpdir)
19 with open(tmpdirpath / "A.py", "w") as f:
22 with open(tmpdirpath / "A.py", "w") as f:
20 f.write("class Object(object):\n pass\n")
23 f.write("class Object:\n pass\nok = True\n")
21 with open(tmpdirpath / "B.py", "w") as f:
24 with open(tmpdirpath / "B.py", "w") as f:
22 f.write("import A\n")
25 f.write("import A\nassert A.ok, 'we are fine'\n")
23 import A
26 import A
24 import B
27 import B
25
28
26 # Test that A is not reloaded.
29 # Test that A is not reloaded.
27 obj = A.Object()
30 obj = A.Object()
28 dreload(B, exclude=["A"])
31 dreload(B, exclude=["A"])
29 assert isinstance(obj, A.Object) is True
32 assert isinstance(obj, A.Object) is True
30
33
34 # Test that an import failure will not blow-up us.
35 A.ok = False
36 with pytest.raises(AssertionError, match="we are fine"):
37 dreload(B, exclude=["A"])
38 assert len(modules_reloading) == 0
39 assert not A.ok
40
31 # Test that A is reloaded.
41 # Test that A is reloaded.
32 obj = A.Object()
42 obj = A.Object()
43 A.ok = False
33 dreload(B)
44 dreload(B)
45 assert A.ok
34 assert isinstance(obj, A.Object) is False
46 assert isinstance(obj, A.Object) is False
47
48
49 def test_not_module():
50 pytest.raises(TypeError, dreload, "modulename")
51
52
53 def test_not_in_sys_modules():
54 fake_module = types.ModuleType("fake_module")
55 with pytest.raises(ImportError, match="not in sys.modules"):
56 dreload(fake_module)
General Comments 0
You need to be logged in to leave comments. Login now