##// END OF EJS Templates
ShimImporter: implement modern interface
Nikita Kniazev -
Show More
@@ -3,6 +3,8 b''
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import importlib.abc
7 import importlib.util
6 import sys
8 import sys
7 import types
9 import types
8 from importlib import import_module
10 from importlib import import_module
@@ -13,41 +15,26 b' from .importstring import import_item'
13 class ShimWarning(Warning):
15 class ShimWarning(Warning):
14 """A warning to show when a module has moved, and a shim is in its place."""
16 """A warning to show when a module has moved, and a shim is in its place."""
15
17
16 class ShimImporter(object):
18
19 class ShimImporter(importlib.abc.MetaPathFinder):
17 """Import hook for a shim.
20 """Import hook for a shim.
18
21
19 This ensures that submodule imports return the real target module,
22 This ensures that submodule imports return the real target module,
20 not a clone that will confuse `is` and `isinstance` checks.
23 not a clone that will confuse `is` and `isinstance` checks.
21 """
24 """
22 def __init__(self, src, mirror):
25 def __init__(self, src, mirror):
23 self.src = src
26 self.src = src
24 self.mirror = mirror
27 self.mirror = mirror
25
28
26 def _mirror_name(self, fullname):
29 def _mirror_name(self, fullname):
27 """get the name of the mirrored module"""
30 """get the name of the mirrored module"""
28
29 return self.mirror + fullname[len(self.src):]
30
31
31 def find_module(self, fullname, path=None):
32 return self.mirror + fullname[len(self.src) :]
32 """Return self if we should be used to import the module."""
33 if fullname.startswith(self.src + '.'):
34 mirror_name = self._mirror_name(fullname)
35 try:
36 mod = import_item(mirror_name)
37 except ImportError:
38 return
39 else:
40 if not isinstance(mod, types.ModuleType):
41 # not a module
42 return None
43 return self
44
33
45 def load_module(self, fullname):
34 def find_spec(self, fullname, path, target=None):
46 """Import the mirrored module, and insert it into sys.modules"""
35 if fullname.startswith(self.src + "."):
47 mirror_name = self._mirror_name(fullname)
36 mirror_name = self._mirror_name(fullname)
48 mod = import_item(mirror_name)
37 return importlib.util.find_spec(mirror_name)
49 sys.modules[fullname] = mod
50 return mod
51
38
52
39
53 class ShimModule(types.ModuleType):
40 class ShimModule(types.ModuleType):
@@ -8,3 +8,7 b' def test_shim_warning():'
8 sys.modules.pop('IPython.config', None)
8 sys.modules.pop('IPython.config', None)
9 with pytest.warns(ShimWarning):
9 with pytest.warns(ShimWarning):
10 import IPython.config
10 import IPython.config
11
12 import traitlets.config
13
14 assert IPython.config.Config is traitlets.config.Config
General Comments 0
You need to be logged in to leave comments. Login now