##// END OF EJS Templates
Port `deepreload` to `importlib`...
Nikita Kniazev -
Show More
@@ -1,341 +1,299 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 Provides a reload() function that acts recursively.
4 4
5 5 Python's normal :func:`python:reload` function only reloads the module that it's
6 6 passed. The :func:`reload` function in this module also reloads everything
7 7 imported from that module, which is useful when you're changing files deep
8 8 inside a package.
9 9
10 10 To use this as your default reload function, type this::
11 11
12 12 import builtins
13 13 from IPython.lib import deepreload
14 14 builtins.reload = deepreload.reload
15 15
16 16 A reference to the original :func:`python:reload` is stored in this module as
17 17 :data:`original_reload`, so you can restore it later.
18 18
19 19 This code is almost entirely based on knee.py, which is a Python
20 20 re-implementation of hierarchical module import.
21 21 """
22 22 #*****************************************************************************
23 23 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
24 24 #
25 25 # Distributed under the terms of the BSD License. The full license is in
26 26 # the file COPYING, distributed as part of this software.
27 27 #*****************************************************************************
28 28
29 29 import builtins as builtin_mod
30 30 from contextlib import contextmanager
31 import imp
31 import importlib
32 32 import sys
33 33
34 34 from types import ModuleType
35 35 from warnings import warn
36 36 import types
37 37
38 38 original_import = builtin_mod.__import__
39 39
40 40 @contextmanager
41 41 def replace_import_hook(new_import):
42 42 saved_import = builtin_mod.__import__
43 43 builtin_mod.__import__ = new_import
44 44 try:
45 45 yield
46 46 finally:
47 47 builtin_mod.__import__ = saved_import
48 48
49 49 def get_parent(globals, level):
50 50 """
51 51 parent, name = get_parent(globals, level)
52 52
53 53 Return the package that an import is being performed in. If globals comes
54 54 from the module foo.bar.bat (not itself a package), this returns the
55 55 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
56 56 the package's entry in sys.modules is returned.
57 57
58 58 If globals doesn't come from a package or a module in a package, or a
59 59 corresponding entry is not found in sys.modules, None is returned.
60 60 """
61 61 orig_level = level
62 62
63 63 if not level or not isinstance(globals, dict):
64 64 return None, ''
65 65
66 66 pkgname = globals.get('__package__', None)
67 67
68 68 if pkgname is not None:
69 69 # __package__ is set, so use it
70 70 if not hasattr(pkgname, 'rindex'):
71 71 raise ValueError('__package__ set to non-string')
72 72 if len(pkgname) == 0:
73 73 if level > 0:
74 74 raise ValueError('Attempted relative import in non-package')
75 75 return None, ''
76 76 name = pkgname
77 77 else:
78 78 # __package__ not set, so figure it out and set it
79 79 if '__name__' not in globals:
80 80 return None, ''
81 81 modname = globals['__name__']
82 82
83 83 if '__path__' in globals:
84 84 # __path__ is set, so modname is already the package name
85 85 globals['__package__'] = name = modname
86 86 else:
87 87 # Normal module, so work out the package name if any
88 88 lastdot = modname.rfind('.')
89 89 if lastdot < 0 < level:
90 90 raise ValueError("Attempted relative import in non-package")
91 91 if lastdot < 0:
92 92 globals['__package__'] = None
93 93 return None, ''
94 94 globals['__package__'] = name = modname[:lastdot]
95 95
96 96 dot = len(name)
97 97 for x in range(level, 1, -1):
98 98 try:
99 99 dot = name.rindex('.', 0, dot)
100 100 except ValueError as e:
101 101 raise ValueError("attempted relative import beyond top-level "
102 102 "package") from e
103 103 name = name[:dot]
104 104
105 105 try:
106 106 parent = sys.modules[name]
107 107 except BaseException as e:
108 108 if orig_level < 1:
109 109 warn("Parent module '%.200s' not found while handling absolute "
110 110 "import" % name)
111 111 parent = None
112 112 else:
113 113 raise SystemError("Parent module '%.200s' not loaded, cannot "
114 114 "perform relative import" % name) from e
115 115
116 116 # We expect, but can't guarantee, if parent != None, that:
117 117 # - parent.__name__ == name
118 118 # - parent.__dict__ is globals
119 119 # If this is violated... Who cares?
120 120 return parent, name
121 121
122 122 def load_next(mod, altmod, name, buf):
123 123 """
124 124 mod, name, buf = load_next(mod, altmod, name, buf)
125 125
126 126 altmod is either None or same as mod
127 127 """
128 128
129 129 if len(name) == 0:
130 130 # completely empty module name should only happen in
131 131 # 'from . import' (or '__import__("")')
132 132 return mod, None, buf
133 133
134 134 dot = name.find('.')
135 135 if dot == 0:
136 136 raise ValueError('Empty module name')
137 137
138 138 if dot < 0:
139 139 subname = name
140 140 next = None
141 141 else:
142 142 subname = name[:dot]
143 143 next = name[dot+1:]
144 144
145 145 if buf != '':
146 146 buf += '.'
147 147 buf += subname
148 148
149 149 result = import_submodule(mod, subname, buf)
150 150 if result is None and mod != altmod:
151 151 result = import_submodule(altmod, subname, subname)
152 152 if result is not None:
153 153 buf = subname
154 154
155 155 if result is None:
156 156 raise ImportError("No module named %.200s" % name)
157 157
158 158 return result, next, buf
159 159
160 160
161 161 # Need to keep track of what we've already reloaded to prevent cyclic evil
162 162 found_now = {}
163 163
164 164 def import_submodule(mod, subname, fullname):
165 165 """m = import_submodule(mod, subname, fullname)"""
166 166 # Require:
167 167 # if mod == None: subname == fullname
168 168 # else: mod.__name__ + "." + subname == fullname
169 169
170 170 global found_now
171 171 if fullname in found_now and fullname in sys.modules:
172 172 m = sys.modules[fullname]
173 173 else:
174 174 print('Reloading', fullname)
175 175 found_now[fullname] = 1
176 176 oldm = sys.modules.get(fullname, None)
177
178 if mod is None:
179 path = None
180 elif hasattr(mod, '__path__'):
181 path = mod.__path__
182 else:
183 return None
184
185 try:
186 # This appears to be necessary on Python 3, because imp.find_module()
187 # tries to import standard libraries (like io) itself, and we don't
188 # want them to be processed by our deep_import_hook.
189 with replace_import_hook(original_import):
190 fp, filename, stuff = imp.find_module(subname, path)
191 except ImportError:
192 return None
193
194 177 try:
195 m = imp.load_module(fullname, fp, filename, stuff)
178 if oldm is not None:
179 m = importlib.reload(oldm)
180 else:
181 m = importlib.import_module(subname, mod)
196 182 except:
197 183 # load_module probably removed name from modules because of
198 184 # the error. Put back the original module object.
199 185 if oldm:
200 186 sys.modules[fullname] = oldm
201 187 raise
202 finally:
203 if fp: fp.close()
204 188
205 189 add_submodule(mod, m, fullname, subname)
206 190
207 191 return m
208 192
209 193 def add_submodule(mod, submod, fullname, subname):
210 194 """mod.{subname} = submod"""
211 195 if mod is None:
212 196 return #Nothing to do here.
213 197
214 198 if submod is None:
215 199 submod = sys.modules[fullname]
216 200
217 201 setattr(mod, subname, submod)
218 202
219 203 return
220 204
221 205 def ensure_fromlist(mod, fromlist, buf, recursive):
222 206 """Handle 'from module import a, b, c' imports."""
223 207 if not hasattr(mod, '__path__'):
224 208 return
225 209 for item in fromlist:
226 210 if not hasattr(item, 'rindex'):
227 211 raise TypeError("Item in ``from list'' not a string")
228 212 if item == '*':
229 213 if recursive:
230 214 continue # avoid endless recursion
231 215 try:
232 216 all = mod.__all__
233 217 except AttributeError:
234 218 pass
235 219 else:
236 220 ret = ensure_fromlist(mod, all, buf, 1)
237 221 if not ret:
238 222 return 0
239 223 elif not hasattr(mod, item):
240 224 import_submodule(mod, item, buf + '.' + item)
241 225
242 226 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
243 227 """Replacement for __import__()"""
244 228 parent, buf = get_parent(globals, level)
245 229
246 230 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
247 231
248 232 tail = head
249 233 while name:
250 234 tail, name, buf = load_next(tail, tail, name, buf)
251 235
252 236 # If tail is None, both get_parent and load_next found
253 237 # an empty module name: someone called __import__("") or
254 238 # doctored faulty bytecode
255 239 if tail is None:
256 240 raise ValueError('Empty module name')
257 241
258 242 if not fromlist:
259 243 return head
260 244
261 245 ensure_fromlist(tail, fromlist, buf, 0)
262 246 return tail
263 247
264 248 modules_reloading = {}
265 249
266 250 def deep_reload_hook(m):
267 251 """Replacement for reload()."""
268 252 # Hardcode this one as it would raise a NotImplementedError from the
269 253 # bowels of Python and screw up the import machinery after.
270 254 # unlike other imports the `exclude` list already in place is not enough.
271 255
272 256 if m is types:
273 257 return m
274 258 if not isinstance(m, ModuleType):
275 259 raise TypeError("reload() argument must be module")
276 260
277 261 name = m.__name__
278 262
279 263 if name not in sys.modules:
280 264 raise ImportError("reload(): module %.200s not in sys.modules" % name)
281 265
282 266 global modules_reloading
283 267 try:
284 268 return modules_reloading[name]
285 269 except:
286 270 modules_reloading[name] = m
287 271
288 dot = name.rfind('.')
289 if dot < 0:
290 subname = name
291 path = None
292 else:
293 try:
294 parent = sys.modules[name[:dot]]
295 except KeyError as e:
296 modules_reloading.clear()
297 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot]) from e
298 subname = name[dot+1:]
299 path = getattr(parent, "__path__", None)
300
301 try:
302 # This appears to be necessary on Python 3, because imp.find_module()
303 # tries to import standard libraries (like io) itself, and we don't
304 # want them to be processed by our deep_import_hook.
305 with replace_import_hook(original_import):
306 fp, filename, stuff = imp.find_module(subname, path)
307 finally:
308 modules_reloading.clear()
309
310 272 try:
311 newm = imp.load_module(name, fp, filename, stuff)
273 newm = importlib.reload(m)
312 274 except:
313 # load_module probably removed name from modules because of
314 # the error. Put back the original module object.
315 275 sys.modules[name] = m
316 276 raise
317 277 finally:
318 if fp: fp.close()
319
320 modules_reloading.clear()
278 modules_reloading.clear()
321 279 return newm
322 280
323 281 # Save the original hooks
324 original_reload = imp.reload
282 original_reload = importlib.reload
325 283
326 284 # Replacement for reload()
327 285 def reload(module, exclude=('sys', 'os.path', 'builtins', '__main__',
328 286 'numpy', 'numpy._globals')):
329 287 """Recursively reload all modules used in the given module. Optionally
330 288 takes a list of modules to exclude from reloading. The default exclude
331 289 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
332 290 display, exception, and io hooks.
333 291 """
334 292 global found_now
335 293 for i in exclude:
336 294 found_now[i] = 1
337 295 try:
338 296 with replace_import_hook(deep_import_hook):
339 297 return deep_reload_hook(module)
340 298 finally:
341 299 found_now = {}
@@ -1,34 +1,56 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Test suite for the deepreload module."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 import pytest
8 import types
9
7 10 from pathlib import Path
8 11
9 12 from IPython.utils.syspathcontext import prepended_to_syspath
10 13 from IPython.utils.tempdir import TemporaryDirectory
11 from IPython.lib.deepreload import reload as dreload
14 from IPython.lib.deepreload import reload as dreload, modules_reloading
12 15
13 16
14 17 def test_deepreload():
15 18 "Test that dreload does deep reloads and skips excluded modules."
16 19 with TemporaryDirectory() as tmpdir:
17 20 with prepended_to_syspath(tmpdir):
18 21 tmpdirpath = Path(tmpdir)
19 22 with open(tmpdirpath / "A.py", "w") as f:
20 f.write("class Object(object):\n pass\n")
23 f.write("class Object:\n pass\nok = True\n")
21 24 with open(tmpdirpath / "B.py", "w") as f:
22 f.write("import A\n")
25 f.write("import A\nassert A.ok, 'we are fine'\n")
23 26 import A
24 27 import B
25 28
26 29 # Test that A is not reloaded.
27 30 obj = A.Object()
28 31 dreload(B, exclude=["A"])
29 32 assert isinstance(obj, A.Object) is True
30 33
34 # Test that an import failure will not blow-up us.
35 A.ok = False
36 with pytest.raises(AssertionError, match="we are fine"):
37 dreload(B, exclude=["A"])
38 assert len(modules_reloading) == 0
39 assert not A.ok
40
31 41 # Test that A is reloaded.
32 42 obj = A.Object()
43 A.ok = False
33 44 dreload(B)
45 assert A.ok
34 46 assert isinstance(obj, A.Object) is False
47
48
49 def test_not_module():
50 pytest.raises(TypeError, dreload, "modulename")
51
52
53 def test_not_in_sys_modules():
54 fake_module = types.ModuleType("fake_module")
55 with pytest.raises(ImportError, match="not in sys.modules"):
56 dreload(fake_module)
General Comments 0
You need to be logged in to leave comments. Login now