diff --git a/IPython/lib/deepreload.py b/IPython/lib/deepreload.py index 81ac6bf..0f9c242 100644 --- a/IPython/lib/deepreload.py +++ b/IPython/lib/deepreload.py @@ -14,9 +14,9 @@ Alternatively, you can add a dreload builtin alongside normal reload with:: __builtin__.dreload = deepreload.reload -This code is almost entirely based on knee.py from the standard library. +This code is almost entirely based on knee.py, which is a Python +re-implementation of hierarchical module import. """ - #***************************************************************************** # Copyright (C) 2001 Nathaniel Gray # @@ -28,135 +28,267 @@ import __builtin__ import imp import sys -# Replacement for __import__() -def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1): - # For now level is ignored, it's just there to prevent crash - # with from __future__ import absolute_import - parent = determine_parent(globals) - q, tail = find_head_package(parent, name) - m = load_tail(q, tail) - if not fromlist: - return q - if hasattr(m, "__path__"): - ensure_fromlist(m, fromlist) - return m +from types import ModuleType +from warnings import warn + +def get_parent(globals, level): + """ + parent, name = get_parent(globals, level) + + Return the package that an import is being performed in. If globals comes + from the module foo.bar.bat (not itself a package), this returns the + sys.modules entry for foo.bar. If globals is from a package's __init__.py, + the package's entry in sys.modules is returned. + + If globals doesn't come from a package or a module in a package, or a + corresponding entry is not found in sys.modules, None is returned. + """ + orig_level = level -def determine_parent(globals): - if not globals or not globals.has_key("__name__"): - return None - pname = globals['__name__'] - if globals.has_key("__path__"): - parent = sys.modules[pname] - assert globals is parent.__dict__ - return parent - if '.' in pname: - i = pname.rfind('.') - pname = pname[:i] - parent = sys.modules[pname] - assert parent.__name__ == pname - return parent - return None - -def find_head_package(parent, name): - # Import the first - if '.' in name: - # 'some.nested.package' -> head = 'some', tail = 'nested.package' - i = name.find('.') - head = name[:i] - tail = name[i+1:] + if not level or not isinstance(globals, dict): + return None, '' + + pkgname = globals.get('__package__', None) + + if pkgname is not None: + # __package__ is set, so use it + if not hasattr(pkgname, 'rindex'): + raise ValueError('__package__ set to non-string') + if len(pkgname) == 0: + if level > 0: + raise ValueError('Attempted relative import in non-package') + return None, '' + name = pkgname else: - # 'packagename' -> head = 'packagename', tail = '' - head = name - tail = "" - if parent: - # If this is a subpackage then qname = parent's name + head - qname = "%s.%s" % (parent.__name__, head) + # __package__ not set, so figure it out and set it + if '__name__' not in globals: + return None, '' + modname = globals['__name__'] + + if '__path__' in globals: + # __path__ is set, so modname is already the package name + globals['__package__'] = name = modname + else: + # Normal module, so work out the package name if any + lastdot = modname.rfind('.') + if lastdot < 0 and level > 0: + raise ValueError("Attempted relative import in non-package") + if lastdot < 0: + globals['__package__'] = None + return None, '' + globals['__package__'] = name = modname[:lastdot] + + dot = len(name) + for x in xrange(level, 1, -1): + try: + dot = name.rindex('.', 0, dot) + except ValueError: + raise ValueError("attempted relative import beyond top-level " + "package") + name = name[:dot] + + try: + parent = sys.modules[name] + except: + if orig_level < 1: + warn("Parent module '%.200s' not found while handling absolute " + "import" % name) + parent = None + else: + raise SystemError("Parent module '%.200s' not loaded, cannot " + "perform relative import" % name) + + # We expect, but can't guarantee, if parent != None, that: + # - parent.__name__ == name + # - parent.__dict__ is globals + # If this is violated... Who cares? + return parent, name + +def load_next(mod, altmod, name, buf): + """ + mod, name, buf = load_next(mod, altmod, name, buf) + + altmod is either None or same as mod + """ + + if len(name) == 0: + # completely empty module name should only happen in + # 'from . import' (or '__import__("")') + return mod, None, buf + + dot = name.find('.') + if dot == 0: + raise ValueError('Empty module name') + + if dot < 0: + subname = name + next = None else: - qname = head - q = import_module(head, qname, parent) - if q: return q, tail - if parent: - qname = head - parent = None - q = import_module(head, qname, parent) - if q: return q, tail - raise ImportError, "No module named " + qname - -def load_tail(q, tail): - m = q - while tail: - i = tail.find('.') - if i < 0: i = len(tail) - head, tail = tail[:i], tail[i+1:] - - # fperez: fix dotted.name reloading failures by changing: - #mname = "%s.%s" % (m.__name__, head) - # to: - mname = m.__name__ - # This needs more testing!!! (I don't understand this module too well) - - #print '** head,tail=|%s|->|%s|, mname=|%s|' % (head,tail,mname) # dbg - m = import_module(head, mname, m) - if not m: - raise ImportError, "No module named " + mname - return m + subname = name[:dot] + next = name[dot+1:] -def ensure_fromlist(m, fromlist, recursive=0): - for sub in fromlist: - if sub == "*": - if not recursive: - try: - all = m.__all__ - except AttributeError: - pass - else: - ensure_fromlist(m, all, 1) - continue - if sub != "*" and not hasattr(m, sub): - subname = "%s.%s" % (m.__name__, sub) - submod = import_module(sub, subname, m) - if not submod: - raise ImportError, "No module named " + subname + if buf != '': + buf += '.' + buf += subname + + result = import_submodule(mod, subname, buf) + if result is None and mod != altmod: + result = import_submodule(altmod, subname, subname) + if result is not None: + buf = subname + + if result is None: + raise ImportError("No module named %.200s" % name) + + return result, next, buf # Need to keep track of what we've already reloaded to prevent cyclic evil found_now = {} -def import_module(partname, fqname, parent): +def import_submodule(mod, subname, fullname): + """m = import_submodule(mod, subname, fullname)""" + # Require: + # if mod == None: subname == fullname + # else: mod.__name__ + "." + subname == fullname + global found_now - if found_now.has_key(fqname): + if fullname in found_now and fullname in sys.modules: + m = sys.modules[fullname] + else: + print 'Reloading', fullname + found_now[fullname] = 1 + oldm = sys.modules.get(fullname, None) + + if mod is None: + path = None + elif hasattr(mod, '__path__'): + path = mod.__path__ + else: + return None + try: - return sys.modules[fqname] - except KeyError: - pass + fp, filename, stuff = imp.find_module(subname, path) + except ImportError: + return None + + try: + m = imp.load_module(fullname, fp, filename, stuff) + except: + # load_module probably removed name from modules because of + # the error. Put back the original module object. + if oldm: + sys.modules[fullname] = oldm + raise + finally: + if fp: fp.close() + + add_submodule(mod, m, fullname, subname) + + return m + +def add_submodule(mod, submod, fullname, subname): + """mod.{subname} = submod""" + if mod is None: + return #Nothing to do here. + + if submod is None: + submod = sys.modules[fullname] - print 'Reloading', fqname #, sys.excepthook is sys.__excepthook__, \ - #sys.displayhook is sys.__displayhook__ + setattr(mod, subname, submod) + + return + +def ensure_fromlist(mod, fromlist, buf, recursive): + """Handle 'from module import a, b, c' imports.""" + if not hasattr(mod, '__path__'): + return + for item in fromlist: + if not hasattr(item, 'rindex'): + raise TypeError("Item in ``from list'' not a string") + if item == '*': + if recursive: + continue # avoid endless recursion + try: + all = mod.__all__ + except AttributeError: + pass + else: + ret = ensure_fromlist(mod, all, buf, 1) + if not ret: + return 0 + elif not hasattr(mod, item): + import_submodule(mod, item, buf + '.' + item) + +def import_module_level(name, globals=None, locals=None, fromlist=None, level=-1): + """Replacement for __import__()""" + parent, buf = get_parent(globals, level) + + head, name, buf = load_next(parent, None if level < 0 else parent, name, buf) + + tail = head + while name: + tail, name, buf = load_next(tail, tail, name, buf) + + # If tail is None, both get_parent and load_next found + # an empty module name: someone called __import__("") or + # doctored faulty bytecode + if tail is None: + raise ValueError('Empty module name') + + if not fromlist: + return head - found_now[fqname] = 1 + ensure_fromlist(tail, fromlist, buf, 0) + return tail + +modules_reloading = {} + +def reload_module(m): + """Replacement for reload().""" + if not isinstance(m, ModuleType): + raise TypeError("reload() argument must be module") + + name = m.__name__ + + if name not in sys.modules: + raise ImportError("reload(): module %.200s not in sys.modules" % name) + + global modules_reloading try: - fp, pathname, stuff = imp.find_module(partname, - parent and parent.__path__) - except ImportError: - return None + return modules_reloading[name] + except: + modules_reloading[name] = m + + dot = name.rfind('.') + if dot < 0: + subname = name + path = None + else: + try: + parent = sys.modules[name[:dot]] + except KeyError: + modules_reloading.clear() + raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot]) + subname = name[dot+1:] + path = getattr(parent, "__path__", None) try: - m = imp.load_module(fqname, fp, pathname, stuff) + fp, filename, stuff = imp.find_module(subname, path) finally: - if fp: fp.close() + modules_reloading.clear() - if parent: - setattr(parent, partname, m) - - return m + try: + newm = imp.load_module(name, fp, filename, stuff) + except: + # load_module probably removed name from modules because of + # the error. Put back the original module object. + sys.modules[name] = m + raise + finally: + if fp: fp.close() -def deep_reload_hook(module): - name = module.__name__ - if '.' not in name: - return import_module(name, name, None) - i = name.rfind('.') - pname = name[:i] - parent = sys.modules[pname] - return import_module(name[i+1:], name, parent) + modules_reloading.clear() + return newm # Save the original hooks try: @@ -165,7 +297,7 @@ except AttributeError: original_reload = imp.reload # Python 3 # Replacement for reload() -def reload(module, exclude=['sys', '__builtin__', '__main__']): +def reload(module, exclude=['sys', 'os.path', '__builtin__', '__main__']): """Recursively reload all modules used in the given module. Optionally takes a list of modules to exclude from reloading. The default exclude list contains sys, __main__, and __builtin__, to prevent, e.g., resetting @@ -175,9 +307,9 @@ def reload(module, exclude=['sys', '__builtin__', '__main__']): for i in exclude: found_now[i] = 1 original_import = __builtin__.__import__ - __builtin__.__import__ = deep_import_hook + __builtin__.__import__ = import_module_level try: - ret = deep_reload_hook(module) + ret = reload_module(module) finally: __builtin__.__import__ = original_import found_now = {}