##// END OF EJS Templates
ShimImporter: implement modern interface
Nikita Kniazev -
Show More
@@ -1,94 +1,81 b''
1 """A shim module for deprecated imports
1 """A shim module for deprecated imports
2 """
2 """
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import importlib.abc
7 import importlib.util
6 import sys
8 import sys
7 import types
9 import types
8 from importlib import import_module
10 from importlib import import_module
9
11
10 from .importstring import import_item
12 from .importstring import import_item
11
13
12
14
13 class ShimWarning(Warning):
15 class ShimWarning(Warning):
14 """A warning to show when a module has moved, and a shim is in its place."""
16 """A warning to show when a module has moved, and a shim is in its place."""
15
17
16 class ShimImporter(object):
18
19 class ShimImporter(importlib.abc.MetaPathFinder):
17 """Import hook for a shim.
20 """Import hook for a shim.
18
21
19 This ensures that submodule imports return the real target module,
22 This ensures that submodule imports return the real target module,
20 not a clone that will confuse `is` and `isinstance` checks.
23 not a clone that will confuse `is` and `isinstance` checks.
21 """
24 """
22 def __init__(self, src, mirror):
25 def __init__(self, src, mirror):
23 self.src = src
26 self.src = src
24 self.mirror = mirror
27 self.mirror = mirror
25
28
26 def _mirror_name(self, fullname):
29 def _mirror_name(self, fullname):
27 """get the name of the mirrored module"""
30 """get the name of the mirrored module"""
28
29 return self.mirror + fullname[len(self.src):]
30
31
31 def find_module(self, fullname, path=None):
32 return self.mirror + fullname[len(self.src) :]
32 """Return self if we should be used to import the module."""
33 if fullname.startswith(self.src + '.'):
34 mirror_name = self._mirror_name(fullname)
35 try:
36 mod = import_item(mirror_name)
37 except ImportError:
38 return
39 else:
40 if not isinstance(mod, types.ModuleType):
41 # not a module
42 return None
43 return self
44
33
45 def load_module(self, fullname):
34 def find_spec(self, fullname, path, target=None):
46 """Import the mirrored module, and insert it into sys.modules"""
35 if fullname.startswith(self.src + "."):
47 mirror_name = self._mirror_name(fullname)
36 mirror_name = self._mirror_name(fullname)
48 mod = import_item(mirror_name)
37 return importlib.util.find_spec(mirror_name)
49 sys.modules[fullname] = mod
50 return mod
51
38
52
39
53 class ShimModule(types.ModuleType):
40 class ShimModule(types.ModuleType):
54
41
55 def __init__(self, *args, **kwargs):
42 def __init__(self, *args, **kwargs):
56 self._mirror = kwargs.pop("mirror")
43 self._mirror = kwargs.pop("mirror")
57 src = kwargs.pop("src", None)
44 src = kwargs.pop("src", None)
58 if src:
45 if src:
59 kwargs['name'] = src.rsplit('.', 1)[-1]
46 kwargs['name'] = src.rsplit('.', 1)[-1]
60 super(ShimModule, self).__init__(*args, **kwargs)
47 super(ShimModule, self).__init__(*args, **kwargs)
61 # add import hook for descendent modules
48 # add import hook for descendent modules
62 if src:
49 if src:
63 sys.meta_path.append(
50 sys.meta_path.append(
64 ShimImporter(src=src, mirror=self._mirror)
51 ShimImporter(src=src, mirror=self._mirror)
65 )
52 )
66
53
67 @property
54 @property
68 def __path__(self):
55 def __path__(self):
69 return []
56 return []
70
57
71 @property
58 @property
72 def __spec__(self):
59 def __spec__(self):
73 """Don't produce __spec__ until requested"""
60 """Don't produce __spec__ until requested"""
74 return import_module(self._mirror).__spec__
61 return import_module(self._mirror).__spec__
75
62
76 def __dir__(self):
63 def __dir__(self):
77 return dir(import_module(self._mirror))
64 return dir(import_module(self._mirror))
78
65
79 @property
66 @property
80 def __all__(self):
67 def __all__(self):
81 """Ensure __all__ is always defined"""
68 """Ensure __all__ is always defined"""
82 mod = import_module(self._mirror)
69 mod = import_module(self._mirror)
83 try:
70 try:
84 return mod.__all__
71 return mod.__all__
85 except AttributeError:
72 except AttributeError:
86 return [name for name in dir(mod) if not name.startswith('_')]
73 return [name for name in dir(mod) if not name.startswith('_')]
87
74
88 def __getattr__(self, key):
75 def __getattr__(self, key):
89 # Use the equivalent of import_item(name), see below
76 # Use the equivalent of import_item(name), see below
90 name = "%s.%s" % (self._mirror, key)
77 name = "%s.%s" % (self._mirror, key)
91 try:
78 try:
92 return import_item(name)
79 return import_item(name)
93 except ImportError as e:
80 except ImportError as e:
94 raise AttributeError(key) from e
81 raise AttributeError(key) from e
@@ -1,10 +1,14 b''
1 import pytest
1 import pytest
2 import sys
2 import sys
3
3
4 from IPython.utils.shimmodule import ShimWarning
4 from IPython.utils.shimmodule import ShimWarning
5
5
6
6
7 def test_shim_warning():
7 def test_shim_warning():
8 sys.modules.pop('IPython.config', None)
8 sys.modules.pop('IPython.config', None)
9 with pytest.warns(ShimWarning):
9 with pytest.warns(ShimWarning):
10 import IPython.config
10 import IPython.config
11
12 import traitlets.config
13
14 assert IPython.config.Config is traitlets.config.Config
General Comments 0
You need to be logged in to leave comments. Login now