From 088eb7d0a214dd341248c4f02c3ea43144a09530 2011-10-01 00:54:11
From: Fernando Perez
Date: 2011-10-01 00:54:11
Subject: [PATCH] Merge branch 'autoreload': port autoreload extension to current API

This restores an extension that had been in quarantine since our 0.11 refactoring.
---
diff --git a/IPython/extensions/autoreload.py b/IPython/extensions/autoreload.py
new file mode 100644
index 0000000..e62a5fb
--- /dev/null
+++ b/IPython/extensions/autoreload.py
@@ -0,0 +1,498 @@
+"""
+``autoreload`` is an IPython extension that reloads modules
+automatically before executing the line of code typed.
+
+This makes for example the following workflow possible:
+
+.. sourcecode:: ipython
+
+    In [1]: %load_ext autoreload
+
+    In [2]: %autoreload 2
+
+    In [3]: from foo import some_function
+
+    In [4]: some_function()
+    Out[4]: 42
+
+    In [5]: # open foo.py in an editor and change some_function to return 43
+
+    In [6]: some_function()
+    Out[6]: 43
+
+The module was reloaded without reloading it explicitly, and the
+object imported with ``from foo import ...`` was also updated.
+
+Usage
+=====
+
+The following magic commands are provided:
+
+``%autoreload``
+
+    Reload all modules (except those excluded by ``%aimport``)
+    automatically now.
+
+``%autoreload 0``
+
+    Disable automatic reloading.
+
+``%autoreload 1``
+
+    Reload all modules imported with ``%aimport`` every time before
+    executing the Python code typed.
+
+``%autoreload 2``
+
+    Reload all modules (except those excluded by ``%aimport``) every
+    time before executing the Python code typed.
+
+``%aimport``
+
+    List modules which are to be automatically imported or not to be imported.
+
+``%aimport foo``
+
+    Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
+
+``%aimport -foo``
+
+    Mark module 'foo' to not be autoreloaded.
+
+Caveats
+=======
+
+Reloading Python modules in a reliable way is in general difficult,
+and unexpected things may occur. ``%autoreload`` tries to work around
+common pitfalls by replacing function code objects and parts of
+classes previously in the module with new versions. This makes the
+following things work:
+
+- Functions and classes imported via 'from xxx import foo' are upgraded
+  to new versions when 'xxx' is reloaded.
+
+- Methods and properties of classes are upgraded on reload, so that
+  calling 'c.foo()' on an object 'c' created before the reload causes
+  the new code for 'foo' to be executed.
+
+Some of the known remaining caveats are:
+
+- Replacing code objects does not always succeed: changing a @property
+  in a class to an ordinary method or a method to a member variable
+  can cause problems (but in old objects only).
+
+- Functions that are removed (e.g. via monkey-patching) from a module
+  before it is reloaded are not upgraded.
+
+- C extension modules cannot be reloaded, and so cannot be
+  autoreloaded.
+
+"""
+
+skip_doctest = True
+
+# Pauli Virtanen , 2008.
+# Thomas Heller, 2000.
+#
+# This IPython module is written by Pauli Virtanen, based on the autoreload
+# code by Thomas Heller.
+
+#------------------------------------------------------------------------------
+# Autoreload functionality
+#------------------------------------------------------------------------------
+
+import time, os, threading, sys, types, imp, inspect, traceback, atexit
+import weakref
+
+def _get_compiled_ext():
+    """Official way to get the extension of compiled files (.pyc or .pyo)"""
+    for ext, mode, typ in imp.get_suffixes():
+        if typ == imp.PY_COMPILED:
+            return ext
+
+PY_COMPILED_EXT = _get_compiled_ext()
+
+class ModuleReloader(object):
+    """Reload modules whose source file is newer than their compiled file."""
+    enabled = False
+    """Whether this reloader is enabled"""
+
+    check_all = True
+    """Autoreload all modules, not just those listed in 'modules'"""
+
+    def __init__(self):
+        # The containers below must be per instance: as mutable class
+        # attributes they were shared by every ModuleReloader, so state
+        # leaked from one reloader to the next.
+        #
+        # Modules that failed to reload: {py_filename: mtime-on-failed-reload}
+        self.failed = {}
+        # Modules specially marked as autoreloadable.
+        self.modules = {}
+        # Modules specially marked as not autoreloadable.
+        self.skip_modules = {}
+        # (module-name, name) -> [weakref or StrongRef to an old object, ...]
+        self.old_objects = {}
+
+    def mark_module_skipped(self, module_name):
+        """Skip reloading the named module in the future"""
+        try:
+            del self.modules[module_name]
+        except KeyError:
+            pass
+        self.skip_modules[module_name] = True
+
+    def mark_module_reloadable(self, module_name):
+        """Reload the named module in the future (if it is imported)"""
+        try:
+            del self.skip_modules[module_name]
+        except KeyError:
+            pass
+        self.modules[module_name] = True
+
+    def aimport_module(self, module_name):
+        """Import a module, and mark it reloadable
+
+        Returns
+        -------
+        top_module : module
+            The module itself if top-level, else its top-level package
+        top_name : str
+            Name of top_module
+
+        """
+        self.mark_module_reloadable(module_name)
+
+        __import__(module_name)
+        top_name = module_name.split('.')[0]
+        top_module = sys.modules[top_name]
+        return top_module, top_name
+
+    def check(self, check_all=False):
+        """Check whether some modules need to be reloaded, and reload them."""
+
+        if not self.enabled and not check_all:
+            return
+
+        if check_all or self.check_all:
+            modules = sys.modules.keys()
+        else:
+            modules = self.modules.keys()
+
+        for modname in modules:
+            m = sys.modules.get(modname, None)
+
+            if modname in self.skip_modules:
+                continue
+
+            if not hasattr(m, '__file__'):
+                continue
+
+            if m.__name__ == '__main__':
+                # we cannot reload(__main__)
+                continue
+
+            filename = m.__file__
+            path, ext = os.path.splitext(filename)
+
+            if ext.lower() == '.py':
+                ext = PY_COMPILED_EXT
+                pyc_filename = path + PY_COMPILED_EXT
+                py_filename = filename
+            else:
+                pyc_filename = filename
+                py_filename = filename[:-1]
+
+            if ext != PY_COMPILED_EXT:
+                continue
+
+            try:
+                pymtime = os.stat(py_filename).st_mtime
+                if pymtime <= os.stat(pyc_filename).st_mtime:
+                    continue
+                if self.failed.get(py_filename, None) == pymtime:
+                    continue
+            except OSError:
+                continue
+
+            try:
+                superreload(m, reload, self.old_objects)
+                if py_filename in self.failed:
+                    del self.failed[py_filename]
+            except Exception:
+                print >> sys.stderr, "[autoreload of %s failed: %s]" % (
+                          modname, traceback.format_exc(1))
+                self.failed[py_filename] = pymtime
+
+#------------------------------------------------------------------------------
+# superreload
+#------------------------------------------------------------------------------
+
+def update_function(old, new):
+    """Upgrade the code object of a function"""
+    for name in ['func_code', 'func_defaults', 'func_doc',
+                 'func_closure', 'func_globals', 'func_dict']:
+        try:
+            setattr(old, name, getattr(new, name))
+        except (AttributeError, TypeError):
+            pass
+
+def update_class(old, new):
+    """Replace stuff in the __dict__ of a class, and upgrade
+    method code objects"""
+    for key in old.__dict__.keys():
+        old_obj = getattr(old, key)
+
+        try:
+            new_obj = getattr(new, key)
+        except AttributeError:
+            # obsolete attribute: remove it
+            try:
+                delattr(old, key)
+            except (AttributeError, TypeError):
+                pass
+            continue
+
+        if update_generic(old_obj, new_obj): continue
+
+        try:
+            setattr(old, key, new_obj)
+        except (AttributeError, TypeError):
+            pass # skip non-writable attributes
+
+def update_property(old, new):
+    """Replace get/set/del functions of a property"""
+    update_generic(old.fdel, new.fdel)
+    update_generic(old.fget, new.fget)
+    update_generic(old.fset, new.fset)
+
+def isinstance2(a, b, typ):
+    """Return True if both *a* and *b* are instances of *typ*."""
+    return isinstance(a, typ) and isinstance(b, typ)
+
+UPDATE_RULES = [
+    (lambda a, b: isinstance2(a, b, types.ClassType),
+     update_class),
+    (lambda a, b: isinstance2(a, b, types.TypeType),
+     update_class),
+    (lambda a, b: isinstance2(a, b, types.FunctionType),
+     update_function),
+    (lambda a, b: isinstance2(a, b, property),
+     update_property),
+    (lambda a, b: isinstance2(a, b, types.MethodType),
+     lambda a, b: update_function(a.im_func, b.im_func)),
+]
+
+def update_generic(a, b):
+    """Update old object *a* from *b* in place; True if a rule applied."""
+    for type_check, update in UPDATE_RULES:
+        if type_check(a, b):
+            update(a, b)
+            return True
+    return False
+
+class StrongRef(object):
+    """Strong-reference stand-in that can be called like a weakref."""
+    def __init__(self, obj):
+        self.obj = obj
+    def __call__(self):
+        return self.obj
+
+def superreload(module, reload=reload, old_objects=None):
+    """Enhanced version of the builtin reload function.
+
+    superreload remembers objects previously in the module, and
+
+    - upgrades the class dictionary of every old class in the module
+    - upgrades the code object of every old function and method
+    - clears the module's namespace before reloading
+
+    ``old_objects`` maps (module-name, name) to references to earlier
+    versions of each object; pass the same dict on every call to keep
+    upgrading objects left over from older reloads.
+
+    """
+    if old_objects is None:
+        old_objects = {}
+
+    # collect old objects in the module
+    for name, obj in module.__dict__.items():
+        if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
+            continue
+        key = (module.__name__, name)
+        try:
+            old_objects.setdefault(key, []).append(weakref.ref(obj))
+        except TypeError:
+            # weakref doesn't work for all types;
+            # create strong references for 'important' cases
+            if isinstance(obj, types.ClassType):
+                old_objects.setdefault(key, []).append(StrongRef(obj))
+
+    # reload module; snapshot the namespace first so that a failed
+    # reload can always restore it
+    old_dict = module.__dict__.copy()
+    try:
+        # clear namespace first from old cruft
+        old_name = module.__name__
+        module.__dict__.clear()
+        module.__dict__['__name__'] = old_name
+    except (TypeError, AttributeError, KeyError):
+        pass
+
+    try:
+        module = reload(module)
+    except:
+        # restore module dictionary on failed reload
+        module.__dict__.update(old_dict)
+        raise
+
+    # iterate over all objects and update functions & classes
+    for name, new_obj in module.__dict__.items():
+        key = (module.__name__, name)
+        if key not in old_objects: continue
+
+        new_refs = []
+        for old_ref in old_objects[key]:
+            old_obj = old_ref()
+            if old_obj is None: continue
+            new_refs.append(old_ref)
+            update_generic(old_obj, new_obj)
+
+        if new_refs:
+            old_objects[key] = new_refs
+        else:
+            del old_objects[key]
+
+    return module
+
+#------------------------------------------------------------------------------
+# IPython connectivity
+#------------------------------------------------------------------------------
+
+from IPython.core.plugin import Plugin
+from IPython.core.hooks import TryNext
+
+class AutoreloadInterface(object):
+    def __init__(self, *a, **kw):
+        super(AutoreloadInterface, self).__init__(*a, **kw)
+        self._reloader = ModuleReloader()
+        self._reloader.check_all = False
+
+    def magic_autoreload(self, ipself, parameter_s=''):
+        r"""%autoreload => Reload modules automatically
+
+        %autoreload
+        Reload all modules (except those excluded by %aimport) automatically
+        now.
+
+        %autoreload 0
+        Disable automatic reloading.
+
+        %autoreload 1
+        Reload all modules imported with %aimport every time before executing
+        the Python code typed.
+
+        %autoreload 2
+        Reload all modules (except those excluded by %aimport) every time
+        before executing the Python code typed.
+
+        Reloading Python modules in a reliable way is in general
+        difficult, and unexpected things may occur. %autoreload tries to
+        work around common pitfalls by replacing function code objects and
+        parts of classes previously in the module with new versions. This
+        makes the following things work:
+
+        - Functions and classes imported via 'from xxx import foo' are upgraded
+          to new versions when 'xxx' is reloaded.
+
+        - Methods and properties of classes are upgraded on reload, so that
+          calling 'c.foo()' on an object 'c' created before the reload causes
+          the new code for 'foo' to be executed.
+
+        Some of the known remaining caveats are:
+
+        - Replacing code objects does not always succeed: changing a @property
+          in a class to an ordinary method or a method to a member variable
+          can cause problems (but in old objects only).
+
+        - Functions that are removed (e.g. via monkey-patching) from a module
+          before it is reloaded are not upgraded.
+
+        - C extension modules cannot be reloaded, and so cannot be
+          autoreloaded.
+
+        """
+        if parameter_s == '':
+            self._reloader.check(True)
+        elif parameter_s == '0':
+            self._reloader.enabled = False
+        elif parameter_s == '1':
+            self._reloader.check_all = False
+            self._reloader.enabled = True
+        elif parameter_s == '2':
+            self._reloader.check_all = True
+            self._reloader.enabled = True
+
+    def magic_aimport(self, ipself, parameter_s='', stream=None):
+        """%aimport => Import modules for automatic reloading.
+
+        %aimport
+        List modules to automatically import and not to import.
+
+        %aimport foo
+        Import module 'foo' and mark it to be autoreloaded for %autoreload 1
+
+        %aimport -foo
+        Mark module 'foo' to not be autoreloaded for %autoreload 1
+
+        """
+
+        modname = parameter_s
+        if not modname:
+            to_reload = self._reloader.modules.keys()
+            to_reload.sort()
+            to_skip = self._reloader.skip_modules.keys()
+            to_skip.sort()
+            if stream is None:
+                stream = sys.stdout
+            if self._reloader.check_all:
+                stream.write("Modules to reload:\nall-except-skipped\n")
+            else:
+                stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
+            stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
+        elif modname.startswith('-'):
+            modname = modname[1:]
+            self._reloader.mark_module_skipped(modname)
+        else:
+            top_module, top_name = self._reloader.aimport_module(modname)
+
+            # Inject module to user namespace
+            ipself.push({top_name: top_module})
+
+    def pre_run_code_hook(self, ipself):
+        """Reload modified modules before user code runs, if enabled."""
+        if not self._reloader.enabled:
+            raise TryNext
+        try:
+            self._reloader.check()
+        except Exception:
+            # Best effort: a reload problem must never block the user's code.
+            pass
+
+class AutoreloadPlugin(AutoreloadInterface, Plugin):
+    def __init__(self, shell=None, config=None):
+        super(AutoreloadPlugin, self).__init__(shell=shell, config=config)
+
+        self.shell.define_magic('autoreload', self.magic_autoreload)
+        self.shell.define_magic('aimport', self.magic_aimport)
+        self.shell.set_hook('pre_run_code_hook', self.pre_run_code_hook)
+
+_loaded = False
+
+def load_ipython_extension(ip):
+    """Load the extension in IPython."""
+    global _loaded
+    if not _loaded:
+        plugin = AutoreloadPlugin(shell=ip, config=ip.config)
+        ip.plugin_manager.register_plugin('autoreload', plugin)
+        _loaded = True
diff --git a/IPython/extensions/tests/test_autoreload.py b/IPython/extensions/tests/test_autoreload.py
new file mode 100644
index 0000000..c0bb425
--- /dev/null
+++ b/IPython/extensions/tests/test_autoreload.py
@@ -0,0 +1,307 @@
+import os
+import sys
+import tempfile
+import shutil
+import random
+import time
+from StringIO import StringIO
+
+import nose.tools as nt
+
+from IPython.extensions.autoreload import AutoreloadInterface
+from IPython.core.hooks import TryNext
+
+#-----------------------------------------------------------------------------
+# Test fixture
+#-----------------------------------------------------------------------------
+
+class FakeShell(object):  # minimal shell stand-in: run_code(), push(), magics
+    def __init__(self):
+        self.ns = {}
+        self.reloader = AutoreloadInterface()
+
+    def run_code(self, code):
+        try:
+            self.reloader.pre_run_code_hook(self)
+        except TryNext:
+            pass
+        exec code in self.ns
+
+    def push(self, items):
+        self.ns.update(items)
+
+    def magic_autoreload(self, parameter):
+        self.reloader.magic_autoreload(self, parameter)
+
+    def magic_aimport(self, parameter, stream=None):
+        self.reloader.magic_aimport(self, parameter, stream=stream)
+
+
+class Fixture(object):
+    """Fixture for creating test module files"""
+
+    test_dir = None
+    old_sys_path = None
+    filename_chars = "abcdefghijklmnopqrstuvwxyz0123456789"  # a-z plus 0-9
+
+    def setUp(self):
+        self.test_dir = tempfile.mkdtemp()
+        self.old_sys_path = list(sys.path)
+        sys.path.insert(0, self.test_dir)
+        self.shell = FakeShell()
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir)
+        sys.path = self.old_sys_path
+        self.shell.reloader._reloader.enabled = False  # flag lives on the ModuleReloader
+
+        self.test_dir = None
+        self.old_sys_path = None
+        self.shell = None
+
+    def get_module(self):
+        module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
+        if module_name in sys.modules:
+            del sys.modules[module_name]
+        file_name = os.path.join(self.test_dir, module_name + ".py")
+        return module_name, file_name
+
+    def write_file(self, filename, content):
+        """
+        Write a file, and force a timestamp difference of at least one second
+
+        Notes
+        -----
+        Python's .pyc files record the timestamp of their compilation
+        with a time resolution of one second.
+
+        Therefore, we need to force a timestamp difference between .py
+        and .pyc, without having the .py file be timestamped in the
+        future, and without changing the timestamp of the .pyc file
+        (because that is stored in the file). The only reliable way
+        to achieve this seems to be to sleep.
+
+        """
+
+        # Sleep one second + eps
+        time.sleep(1.05)
+
+        # Write
+        f = open(filename, 'w')
+        try:
+            f.write(content)
+        finally:
+            f.close()
+
+    def new_module(self, code):
+        mod_name, mod_fn = self.get_module()
+        f = open(mod_fn, 'w')
+        try:
+            f.write(code)
+        finally:
+            f.close()
+        return mod_name, mod_fn
+
+#-----------------------------------------------------------------------------
+# Test automatic reloading
+#-----------------------------------------------------------------------------
+
+class TestAutoreload(Fixture):
+    def _check_smoketest(self, use_aimport=True):
+        """
+        Functional test for the automatic reloader using either
+        '%autoreload 1' or '%autoreload 2'
+        """
+
+        mod_name, mod_fn = self.new_module("""
+x = 9
+
+z = 123  # this item will be deleted
+
+def foo(y):
+    return y + 3
+
+class Baz(object):
+    def __init__(self, x):
+        self.x = x
+    def bar(self, y):
+        return self.x + y
+    @property
+    def quux(self):
+        return 42
+    def zzz(self):
+        '''This method will be deleted below'''
+        return 99
+
+class Bar:    # old-style class: weakref doesn't work for it on Python < 2.7
+    def foo(self):
+        return 1
+""")
+
+        #
+        # Import module, and mark for reloading
+        #
+        if use_aimport:
+            self.shell.magic_autoreload("1")
+            self.shell.magic_aimport(mod_name)
+            stream = StringIO()
+            self.shell.magic_aimport("", stream=stream)
+            nt.assert_true(("Modules to reload:\n%s" % mod_name) in
+                           stream.getvalue())
+
+            nt.assert_raises(
+                ImportError,
+                self.shell.magic_aimport, "tmpmod_as318989e89ds")
+        else:
+            self.shell.magic_autoreload("2")
+            self.shell.run_code("import %s" % mod_name)
+            stream = StringIO()
+            self.shell.magic_aimport("", stream=stream)
+            nt.assert_true("Modules to reload:\nall-except-skipped" in
+                           stream.getvalue())
+        nt.assert_true(mod_name in self.shell.ns)
+
+        mod = sys.modules[mod_name]
+
+        #
+        # Test module contents
+        #
+        old_foo = mod.foo
+        old_obj = mod.Baz(9)
+        old_obj2 = mod.Bar()
+
+        def check_module_contents():
+            nt.assert_equal(mod.x, 9)
+            nt.assert_equal(mod.z, 123)
+
+            nt.assert_equal(old_foo(0), 3)
+            nt.assert_equal(mod.foo(0), 3)
+
+            obj = mod.Baz(9)
+            nt.assert_equal(old_obj.bar(1), 10)
+            nt.assert_equal(obj.bar(1), 10)
+            nt.assert_equal(obj.quux, 42)
+            nt.assert_equal(obj.zzz(), 99)
+
+            obj2 = mod.Bar()
+            nt.assert_equal(old_obj2.foo(), 1)
+            nt.assert_equal(obj2.foo(), 1)
+
+        check_module_contents()
+
+        #
+        # Simulate a failed reload: no reload should occur and exactly
+        # one error message should be printed
+        #
+        self.write_file(mod_fn, """
+a syntax error
+""")
+
+        old_stderr = sys.stderr
+        new_stderr = StringIO()
+        sys.stderr = new_stderr
+        try:
+            self.shell.run_code("pass") # trigger reload
+            self.shell.run_code("pass") # trigger another reload
+            check_module_contents()
+        finally:
+            sys.stderr = old_stderr
+
+        nt.assert_true(('[autoreload of %s failed:' % mod_name) in
+                       new_stderr.getvalue())
+        nt.assert_equal(new_stderr.getvalue().count('[autoreload of'), 1)
+
+        #
+        # Rewrite module (this time reload should succeed)
+        #
+        self.write_file(mod_fn, """
+x = 10
+
+def foo(y):
+    return y + 4
+
+class Baz(object):
+    def __init__(self, x):
+        self.x = x
+    def bar(self, y):
+        return self.x + y + 1
+    @property
+    def quux(self):
+        return 43
+
+class Bar:    # old-style class
+    def foo(self):
+        return 2
+""")
+
+        def check_module_contents():
+            nt.assert_equal(mod.x, 10)
+            nt.assert_false(hasattr(mod, 'z'))
+
+            nt.assert_equal(old_foo(0), 4) # superreload magic!
+            nt.assert_equal(mod.foo(0), 4)
+
+            obj = mod.Baz(9)
+            nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
+            nt.assert_equal(obj.bar(1), 11)
+
+            nt.assert_equal(old_obj.quux, 43)
+            nt.assert_equal(obj.quux, 43)
+
+            nt.assert_false(hasattr(old_obj, 'zzz'))
+            nt.assert_false(hasattr(obj, 'zzz'))
+
+            obj2 = mod.Bar()
+            nt.assert_equal(old_obj2.foo(), 2)
+            nt.assert_equal(obj2.foo(), 2)
+
+        self.shell.run_code("pass") # trigger reload
+        check_module_contents()
+
+        #
+        # Another failure case: deleted file (shouldn't reload)
+        #
+        os.unlink(mod_fn)
+
+        self.shell.run_code("pass") # trigger reload
+        check_module_contents()
+
+        #
+        # Disable autoreload and rewrite module: no reload should occur
+        #
+        if use_aimport:
+            self.shell.magic_aimport("-" + mod_name)
+            stream = StringIO()
+            self.shell.magic_aimport("", stream=stream)
+            nt.assert_true(("Modules to skip:\n%s" % mod_name) in
+                           stream.getvalue())
+
+            # This should succeed, although no such module exists
+            self.shell.magic_aimport("-tmpmod_as318989e89ds")
+        else:
+            self.shell.magic_autoreload("0")
+
+        self.write_file(mod_fn, """
+x = -99
+""")
+
+        self.shell.run_code("pass") # trigger reload
+        self.shell.run_code("pass")
+        check_module_contents()
+
+        #
+        # Re-enable autoreload: reload should now occur
+        #
+        if use_aimport:
+            self.shell.magic_aimport(mod_name)
+        else:
+            self.shell.magic_autoreload("")
+
+        self.shell.run_code("pass") # trigger reload
+        nt.assert_equal(mod.x, -99)
+
+    def test_smoketest_aimport(self):
+        self._check_smoketest(use_aimport=True)
+
+    def test_smoketest_autoreload(self):
+        self._check_smoketest(use_aimport=False)
diff --git a/IPython/quarantine/ipy_autoreload.py b/IPython/quarantine/ipy_autoreload.py
deleted file mode 100644
index 3976551..0000000
--- a/IPython/quarantine/ipy_autoreload.py
+++ /dev/null
@@ -1,350 +0,0 @@
-"""
-IPython extension: autoreload modules before executing the next line
-
-Try::
-
-    %autoreload?
-
-for documentation.
-"""
-
-# Pauli Virtanen , 2008.
-# Thomas Heller, 2000.
-# -# This IPython module is written by Pauli Virtanen, based on the autoreload -# code by Thomas Heller. - -#------------------------------------------------------------------------------ -# Autoreload functionality -#------------------------------------------------------------------------------ - -import time, os, threading, sys, types, imp, inspect, traceback, atexit -import weakref - -def _get_compiled_ext(): - """Official way to get the extension of compiled files (.pyc or .pyo)""" - for ext, mode, typ in imp.get_suffixes(): - if typ == imp.PY_COMPILED: - return ext - -PY_COMPILED_EXT = _get_compiled_ext() - -class ModuleReloader(object): - failed = {} - """Modules that failed to reload: {module: mtime-on-failed-reload, ...}""" - - modules = {} - """Modules specially marked as autoreloadable.""" - - skip_modules = {} - """Modules specially marked as not autoreloadable.""" - - check_all = True - """Autoreload all modules, not just those listed in 'modules'""" - - old_objects = {} - """(module-name, name) -> weakref, for replacing old code objects""" - - def check(self, check_all=False): - """Check whether some modules need to be reloaded.""" - - if check_all or self.check_all: - modules = sys.modules.keys() - else: - modules = self.modules.keys() - - for modname in modules: - m = sys.modules.get(modname, None) - - if modname in self.skip_modules: - continue - - if not hasattr(m, '__file__'): - continue - - if m.__name__ == '__main__': - # we cannot reload(__main__) - continue - - filename = m.__file__ - dirname = os.path.dirname(filename) - path, ext = os.path.splitext(filename) - - if ext.lower() == '.py': - ext = PY_COMPILED_EXT - filename = os.path.join(dirname, path + PY_COMPILED_EXT) - - if ext != PY_COMPILED_EXT: - continue - - try: - pymtime = os.stat(filename[:-1]).st_mtime - if pymtime <= os.stat(filename).st_mtime: - continue - if self.failed.get(filename[:-1], None) == pymtime: - continue - except OSError: - continue - - try: - superreload(m, reload,
self.old_objects) - if filename[:-1] in self.failed: - del self.failed[filename[:-1]] - except: - print >> sys.stderr, "[autoreload of %s failed: %s]" % ( - modname, traceback.format_exc(1)) - self.failed[filename[:-1]] = pymtime - -#------------------------------------------------------------------------------ -# superreload -#------------------------------------------------------------------------------ - -def update_function(old, new): - """Upgrade the code object of a function""" - for name in ['func_code', 'func_defaults', 'func_doc', - 'func_closure', 'func_globals', 'func_dict']: - try: - setattr(old, name, getattr(new, name)) - except (AttributeError, TypeError): - pass - -def update_class(old, new): - """Replace stuff in the __dict__ of a class, and upgrade - method code objects""" - for key in old.__dict__.keys(): - old_obj = getattr(old, key) - - try: - new_obj = getattr(new, key) - except AttributeError: - # obsolete attribute: remove it - try: - delattr(old, key) - except (AttributeError, TypeError): - pass - continue - - if update_generic(old_obj, new_obj): continue - - try: - setattr(old, key, getattr(new, key)) - except (AttributeError, TypeError): - pass # skip non-writable attributes - -def update_property(old, new): - """Replace get/set/del functions of a property""" - update_generic(old.fdel, new.fdel) - update_generic(old.fget, new.fget) - update_generic(old.fset, new.fset) - -def isinstance2(a, b, typ): - return isinstance(a, typ) and isinstance(b, typ) - -UPDATE_RULES = [ - (lambda a, b: isinstance2(a, b, types.ClassType), - update_class), - (lambda a, b: isinstance2(a, b, types.TypeType), - update_class), - (lambda a, b: isinstance2(a, b, types.FunctionType), - update_function), - (lambda a, b: isinstance2(a, b, property), - update_property), - (lambda a, b: isinstance2(a, b, types.MethodType), - lambda a, b: update_function(a.im_func, b.im_func)), -] - -def update_generic(a, b): - for type_check, update in UPDATE_RULES: - if type_check(a,
b): - update(a, b) - return True - return False - -class StrongRef(object): - def __init__(self, obj): - self.obj = obj - def __call__(self): - return self.obj - -def superreload(module, reload=reload, old_objects={}): - """Enhanced version of the builtin reload function. - - superreload remembers objects previously in the module, and - - - upgrades the class dictionary of every old class in the module - - upgrades the code object of every old function and method - - clears the module's namespace before reloading - - """ - - # collect old objects in the module - for name, obj in module.__dict__.items(): - if not hasattr(obj, '__module__') or obj.__module__ != module.__name__: - continue - key = (module.__name__, name) - try: - old_objects.setdefault(key, []).append(weakref.ref(obj)) - except TypeError: - # weakref doesn't work for all types; - # create strong references for 'important' cases - if isinstance(obj, types.ClassType): - old_objects.setdefault(key, []).append(StrongRef(obj)) - - # reload module - try: - # clear namespace first from old cruft - old_name = module.__name__ - module.__dict__.clear() - module.__dict__['__name__'] = old_name - except (TypeError, AttributeError, KeyError): - pass - module = reload(module) - - # iterate over all objects and update functions & classes - for name, new_obj in module.__dict__.items(): - key = (module.__name__, name) - if key not in old_objects: continue - - new_refs = [] - for old_ref in old_objects[key]: - old_obj = old_ref() - if old_obj is None: continue - new_refs.append(old_ref) - update_generic(old_obj, new_obj) - - if new_refs: - old_objects[key] = new_refs - else: - del old_objects[key] - - return module - -reloader = ModuleReloader() - -#------------------------------------------------------------------------------ -# IPython connectivity -#------------------------------------------------------------------------------ -from IPython.core import ipapi -from IPython.core.error import TryNext - -ip =
ipapi.get() - -autoreload_enabled = False - -def runcode_hook(self): - if not autoreload_enabled: - raise TryNext - try: - reloader.check() - except: - pass - -def enable_autoreload(): - global autoreload_enabled - autoreload_enabled = True - -def disable_autoreload(): - global autoreload_enabled - autoreload_enabled = False - -def autoreload_f(self, parameter_s=''): - r""" %autoreload => Reload modules automatically - - %autoreload - Reload all modules (except those excluded by %aimport) automatically now. - - %autoreload 0 - Disable automatic reloading. - - %autoreload 1 - Reload all modules imported with %aimport every time before executing - the Python code typed. - - %autoreload 2 - Reload all modules (except those excluded by %aimport) every time - before executing the Python code typed. - - Reloading Python modules in a reliable way is in general - difficult, and unexpected things may occur. %autoreload tries to - work around common pitfalls by replacing function code objects and - parts of classes previously in the module with new versions. This - makes the following things to work: - - - Functions and classes imported via 'from xxx import foo' are upgraded - to new versions when 'xxx' is reloaded. - - - Methods and properties of classes are upgraded on reload, so that - calling 'c.foo()' on an object 'c' created before the reload causes - the new code for 'foo' to be executed. - - Some of the known remaining caveats are: - - - Replacing code objects does not always succeed: changing a @property - in a class to an ordinary method or a method to a member variable - can cause problems (but in old objects only). - - - Functions that are removed (eg. via monkey-patching) from a module - before it is reloaded are not upgraded. - - - C extension modules cannot be reloaded, and so cannot be - autoreloaded.
- - """ - if parameter_s == '': - reloader.check(True) - elif parameter_s == '0': - disable_autoreload() - elif parameter_s == '1': - reloader.check_all = False - enable_autoreload() - elif parameter_s == '2': - reloader.check_all = True - enable_autoreload() - -def aimport_f(self, parameter_s=''): - """%aimport => Import modules for automatic reloading. - - %aimport - List modules to automatically import and not to import. - - %aimport foo - Import module 'foo' and mark it to be autoreloaded for %autoreload 1 - - %aimport -foo - Mark module 'foo' to not be autoreloaded for %autoreload 1 - - """ - - modname = parameter_s - if not modname: - to_reload = reloader.modules.keys() - to_reload.sort() - to_skip = reloader.skip_modules.keys() - to_skip.sort() - if reloader.check_all: - print "Modules to reload:\nall-expect-skipped" - else: - print "Modules to reload:\n%s" % ' '.join(to_reload) - print "\nModules to skip:\n%s" % ' '.join(to_skip) - elif modname.startswith('-'): - modname = modname[1:] - try: del reloader.modules[modname] - except KeyError: pass - reloader.skip_modules[modname] = True - else: - try: del reloader.skip_modules[modname] - except KeyError: pass - reloader.modules[modname] = True - - # Inject module to user namespace; handle also submodules properly - __import__(modname) - basename = modname.split('.')[0] - mod = sys.modules[basename] - ip.push({basename: mod}) - -def init(): - ip.define_magic('autoreload', autoreload_f) - ip.define_magic('aimport', aimport_f) - ip.set_hook('pre_runcode_hook', runcode_hook) - -init()