##// END OF EJS Templates
ensure `__all__` and `__dir__` are defined on shims...
Min RK -
Show More
@@ -1,78 +1,90 b''
1 1 """A shim module for deprecated imports
2 2 """
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import sys
7 7 import types
8 8
9 9 from .importstring import import_item
10 10
11 11
12 12 class ShimImporter(object):
13 13 """Import hook for a shim.
14 14
15 15 This ensures that submodule imports return the real target module,
16 16 not a clone that will confuse `is` and `isinstance` checks.
17 17 """
18 18 def __init__(self, src, mirror):
19 19 self.src = src
20 20 self.mirror = mirror
21 21
22 22 def _mirror_name(self, fullname):
23 23 """get the name of the mirrored module"""
24 24
25 25 return self.mirror + fullname[len(self.src):]
26 26
27 27 def find_module(self, fullname, path=None):
28 28 """Return self if we should be used to import the module."""
29 29 if fullname.startswith(self.src + '.'):
30 30 mirror_name = self._mirror_name(fullname)
31 31 try:
32 32 mod = import_item(mirror_name)
33 33 except ImportError:
34 34 return
35 35 else:
36 36 if not isinstance(mod, types.ModuleType):
37 37 # not a module
38 38 return None
39 39 return self
40 40
41 41 def load_module(self, fullname):
42 42 """Import the mirrored module, and insert it into sys.modules"""
43 43 mirror_name = self._mirror_name(fullname)
44 44 mod = import_item(mirror_name)
45 45 sys.modules[fullname] = mod
46 46 return mod
47 47
48 48
49 49 class ShimModule(types.ModuleType):
50 50
51 51 def __init__(self, *args, **kwargs):
52 52 self._mirror = kwargs.pop("mirror")
53 53 src = kwargs.pop("src", None)
54 54 if src:
55 55 kwargs['name'] = src.rsplit('.', 1)[-1]
56 56 super(ShimModule, self).__init__(*args, **kwargs)
57 57 # add import hook for descendent modules
58 58 if src:
59 59 sys.meta_path.append(
60 60 ShimImporter(src=src, mirror=self._mirror)
61 61 )
62 62
63 63 @property
64 64 def __path__(self):
65 65 return []
66 66
67 67 @property
68 68 def __spec__(self):
69 69 """Don't produce __spec__ until requested"""
70 70 return __import__(self._mirror).__spec__
71
72 def __dir__(self):
73 return dir(__import__(self._mirror))
74
75 @property
76 def __all__(self):
77 """Ensure __all__ is always defined"""
78 mod = __import__(self._mirror)
79 try:
80 return mod.__all__
81 except AttributeError:
82 return [name for name in dir(mod) if not name.startswith('_')]
71 83
72 84 def __getattr__(self, key):
73 85 # Use the equivalent of import_item(name), see below
74 86 name = "%s.%s" % (self._mirror, key)
75 87 try:
76 88 return import_item(name)
77 89 except ImportError:
78 90 raise AttributeError(key)
General Comments 0
You need to be logged in to leave comments. Login now