##// END OF EJS Templates
Don't find importable non-modules in shim
Min RK -
Show More
@@ -1,75 +1,78 b''
1 1 """A shim module for deprecated imports
2 2 """
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import sys
7 7 import types
8 8
9 9 from .importstring import import_item
10 10
11 11
12 12 class ShimImporter(object):
13 13 """Import hook for a shim.
14 14
15 15 This ensures that submodule imports return the real target module,
16 16 not a clone that will confuse `is` and `isinstance` checks.
17 17 """
18 18 def __init__(self, src, mirror):
19 19 self.src = src
20 20 self.mirror = mirror
21 21
22 22 def _mirror_name(self, fullname):
23 23 """get the name of the mirrored module"""
24 24
25 25 return self.mirror + fullname[len(self.src):]
26 26
27 27 def find_module(self, fullname, path=None):
28 28 """Return self if we should be used to import the module."""
29 29 if fullname.startswith(self.src + '.'):
30 30 mirror_name = self._mirror_name(fullname)
31 31 try:
32 __import__(mirror_name)
32 mod = import_item(mirror_name)
33 33 except ImportError:
34 34 return
35 35 else:
36 if not isinstance(mod, types.ModuleType):
37 # not a module
38 return None
36 39 return self
37 40
38 41 def load_module(self, fullname):
39 42 """Import the mirrored module, and insert it into sys.modules"""
40 43 mirror_name = self._mirror_name(fullname)
41 44 mod = import_item(mirror_name)
42 45 sys.modules[fullname] = mod
43 46 return mod
44 47
45 48
46 49 class ShimModule(types.ModuleType):
47 50
48 51 def __init__(self, *args, **kwargs):
49 52 self._mirror = kwargs.pop("mirror")
50 53 src = kwargs.pop("src", None)
51 54 if src:
52 55 kwargs['name'] = src.rsplit('.', 1)[-1]
53 56 super(ShimModule, self).__init__(*args, **kwargs)
54 57 # add import hook for descendent modules
55 58 if src:
56 59 sys.meta_path.append(
57 60 ShimImporter(src=src, mirror=self._mirror)
58 61 )
59 62
60 63 @property
61 64 def __path__(self):
62 65 return []
63 66
64 67 @property
65 68 def __spec__(self):
66 69 """Don't produce __spec__ until requested"""
67 70 return __import__(self._mirror).__spec__
68 71
69 72 def __getattr__(self, key):
70 73 # Use the equivalent of import_item(name), see below
71 74 name = "%s.%s" % (self._mirror, key)
72 75 try:
73 76 return import_item(name)
74 77 except ImportError:
75 78 raise AttributeError(key)
General Comments 0
You need to be logged in to leave comments. Login now