##// END OF EJS Templates
ensure `__all__` and `__dir__` are defined on shims...
Min RK -
Show More
@@ -1,78 +1,90 b''
1 """A shim module for deprecated imports
1 """A shim module for deprecated imports
2 """
2 """
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import sys
6 import sys
7 import types
7 import types
8
8
9 from .importstring import import_item
9 from .importstring import import_item
10
10
11
11
12 class ShimImporter(object):
12 class ShimImporter(object):
13 """Import hook for a shim.
13 """Import hook for a shim.
14
14
15 This ensures that submodule imports return the real target module,
15 This ensures that submodule imports return the real target module,
16 not a clone that will confuse `is` and `isinstance` checks.
16 not a clone that will confuse `is` and `isinstance` checks.
17 """
17 """
18 def __init__(self, src, mirror):
18 def __init__(self, src, mirror):
19 self.src = src
19 self.src = src
20 self.mirror = mirror
20 self.mirror = mirror
21
21
22 def _mirror_name(self, fullname):
22 def _mirror_name(self, fullname):
23 """get the name of the mirrored module"""
23 """get the name of the mirrored module"""
24
24
25 return self.mirror + fullname[len(self.src):]
25 return self.mirror + fullname[len(self.src):]
26
26
27 def find_module(self, fullname, path=None):
27 def find_module(self, fullname, path=None):
28 """Return self if we should be used to import the module."""
28 """Return self if we should be used to import the module."""
29 if fullname.startswith(self.src + '.'):
29 if fullname.startswith(self.src + '.'):
30 mirror_name = self._mirror_name(fullname)
30 mirror_name = self._mirror_name(fullname)
31 try:
31 try:
32 mod = import_item(mirror_name)
32 mod = import_item(mirror_name)
33 except ImportError:
33 except ImportError:
34 return
34 return
35 else:
35 else:
36 if not isinstance(mod, types.ModuleType):
36 if not isinstance(mod, types.ModuleType):
37 # not a module
37 # not a module
38 return None
38 return None
39 return self
39 return self
40
40
41 def load_module(self, fullname):
41 def load_module(self, fullname):
42 """Import the mirrored module, and insert it into sys.modules"""
42 """Import the mirrored module, and insert it into sys.modules"""
43 mirror_name = self._mirror_name(fullname)
43 mirror_name = self._mirror_name(fullname)
44 mod = import_item(mirror_name)
44 mod = import_item(mirror_name)
45 sys.modules[fullname] = mod
45 sys.modules[fullname] = mod
46 return mod
46 return mod
47
47
48
48
49 class ShimModule(types.ModuleType):
49 class ShimModule(types.ModuleType):
50
50
51 def __init__(self, *args, **kwargs):
51 def __init__(self, *args, **kwargs):
52 self._mirror = kwargs.pop("mirror")
52 self._mirror = kwargs.pop("mirror")
53 src = kwargs.pop("src", None)
53 src = kwargs.pop("src", None)
54 if src:
54 if src:
55 kwargs['name'] = src.rsplit('.', 1)[-1]
55 kwargs['name'] = src.rsplit('.', 1)[-1]
56 super(ShimModule, self).__init__(*args, **kwargs)
56 super(ShimModule, self).__init__(*args, **kwargs)
57 # add import hook for descendent modules
57 # add import hook for descendent modules
58 if src:
58 if src:
59 sys.meta_path.append(
59 sys.meta_path.append(
60 ShimImporter(src=src, mirror=self._mirror)
60 ShimImporter(src=src, mirror=self._mirror)
61 )
61 )
62
62
63 @property
63 @property
64 def __path__(self):
64 def __path__(self):
65 return []
65 return []
66
66
67 @property
67 @property
68 def __spec__(self):
68 def __spec__(self):
69 """Don't produce __spec__ until requested"""
69 """Don't produce __spec__ until requested"""
70 return __import__(self._mirror).__spec__
70 return __import__(self._mirror).__spec__
71
71
72 def __dir__(self):
73 return dir(__import__(self._mirror))
74
75 @property
76 def __all__(self):
77 """Ensure __all__ is always defined"""
78 mod = __import__(self._mirror)
79 try:
80 return mod.__all__
81 except AttributeError:
82 return [name for name in dir(mod) if not name.startswith('_')]
83
72 def __getattr__(self, key):
84 def __getattr__(self, key):
73 # Use the equivalent of import_item(name), see below
85 # Use the equivalent of import_item(name), see below
74 name = "%s.%s" % (self._mirror, key)
86 name = "%s.%s" % (self._mirror, key)
75 try:
87 try:
76 return import_item(name)
88 return import_item(name)
77 except ImportError:
89 except ImportError:
78 raise AttributeError(key)
90 raise AttributeError(key)
General Comments 0
You need to be logged in to leave comments. Login now