diff --git a/IPython/extensions/autoreload.py b/IPython/extensions/autoreload.py index b5e2ece..54ed495 100644 --- a/IPython/extensions/autoreload.py +++ b/IPython/extensions/autoreload.py @@ -130,20 +130,23 @@ class ModuleReloader(object): enabled = False """Whether this reloader is enabled""" - failed = {} - """Modules that failed to reload: {module: mtime-on-failed-reload, ...}""" - - modules = {} - """Modules specially marked as autoreloadable.""" - - skip_modules = {} - """Modules specially marked as not autoreloadable.""" - check_all = True """Autoreload all modules, not just those listed in 'modules'""" - old_objects = {} - """(module-name, name) -> weakref, for replacing old code objects""" + def __init__(self): + # Modules that failed to reload: {module: mtime-on-failed-reload, ...} + self.failed = {} + # Modules specially marked as autoreloadable. + self.modules = {} + # Modules specially marked as not autoreloadable. + self.skip_modules = {} + # (module-name, name) -> weakref, for replacing old code objects + self.old_objects = {} + # Module modification timestamps + self.modules_mtimes = {} + + # Cache module modification times + self.check(check_all=True, do_reload=False) def mark_module_skipped(self, module_name): """Skip reloading the named module in the future""" @@ -179,7 +182,33 @@ class ModuleReloader(object): top_module = sys.modules[top_name] return top_module, top_name - def check(self, check_all=False): + def filename_and_mtime(self, module): + if not hasattr(module, '__file__'): + return None, None + + if module.__name__ == '__main__': + # we cannot reload(__main__) + return None, None + + filename = module.__file__ + path, ext = os.path.splitext(filename) + + if ext.lower() == '.py': + py_filename = filename + else: + try: + py_filename = openpy.source_from_cache(filename) + except ValueError: + return None, None + + try: + pymtime = os.stat(py_filename).st_mtime + except OSError: + return None, None + + return py_filename, pymtime + + def check(self, check_all=False, do_reload=True): """Check whether some modules need to be reloaded.""" if not self.enabled and not check_all: @@ -196,43 +225,32 @@ class ModuleReloader(object): if modname in self.skip_modules: continue - if not hasattr(m, '__file__'): + py_filename, pymtime = self.filename_and_mtime(m) + if py_filename is None: continue - if m.__name__ == '__main__': - # we cannot reload(__main__) - continue - - filename = m.__file__ - path, ext = os.path.splitext(filename) - - if ext.lower() == '.py': - pyc_filename = openpy.cache_from_source(filename) - py_filename = filename - else: - pyc_filename = filename - try: - py_filename = openpy.source_from_cache(filename) - except ValueError: - continue - try: - pymtime = os.stat(py_filename).st_mtime - if pymtime <= os.stat(pyc_filename).st_mtime: + if pymtime <= self.modules_mtimes[modname]: continue + except KeyError: + self.modules_mtimes[modname] = pymtime + continue + else: if self.failed.get(py_filename, None) == pymtime: continue - except OSError: - continue - try: - superreload(m, reload, self.old_objects) - if py_filename in self.failed: - del self.failed[py_filename] - except: - print("[autoreload of %s failed: %s]" % ( - modname, traceback.format_exc(1)), file=sys.stderr) - self.failed[py_filename] = pymtime + self.modules_mtimes[modname] = pymtime + + # If we've reached this point, we should try to reload the module + if do_reload: + try: + superreload(m, reload, self.old_objects) + if py_filename in self.failed: + del self.failed[py_filename] + except: + print("[autoreload of %s failed: %s]" % ( + modname, traceback.format_exc(1)), file=sys.stderr) + self.failed[py_filename] = pymtime #------------------------------------------------------------------------------ # superreload @@ -400,6 +418,7 @@ class AutoreloadMagics(Magics): super(AutoreloadMagics, self).__init__(*a, **kw) self._reloader = ModuleReloader() self._reloader.check_all = False + self.loaded_modules = set(sys.modules) @line_magic def autoreload(self, parameter_s=''): @@ -497,9 +516,21 @@ class AutoreloadMagics(Magics): except: pass + def post_execute_hook(self): + """Cache the modification times of any modules imported in this execution + """ + newly_loaded_modules = set(sys.modules) - self.loaded_modules + for modname in newly_loaded_modules: + _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname]) + if pymtime is not None: + self._reloader.modules_mtimes[modname] = pymtime + + self.loaded_modules.update(newly_loaded_modules) + def load_ipython_extension(ip): """Load the extension in IPython.""" auto_reload = AutoreloadMagics(ip) ip.register_magics(auto_reload) ip.events.register('pre_run_cell', auto_reload.pre_run_cell) + ip.register_post_execute(auto_reload.post_execute_hook) diff --git a/IPython/extensions/tests/test_autoreload.py b/IPython/extensions/tests/test_autoreload.py index d3b061e..968a19f 100644 --- a/IPython/extensions/tests/test_autoreload.py +++ b/IPython/extensions/tests/test_autoreload.py @@ -50,6 +50,7 @@ class FakeShell(object): def run_code(self, code): self.events.trigger('pre_run_cell') exec(code, self.ns) + self.auto_magics.post_execute_hook() def push(self, items): self.ns.update(items) @@ -59,6 +60,7 @@ class FakeShell(object): def magic_aimport(self, parameter, stream=None): self.auto_magics.aimport(parameter, stream=stream) + self.auto_magics.post_execute_hook() class Fixture(object): @@ -168,12 +170,10 @@ class Bar: # old-style class: weakref doesn't work for it on Python < 2.7 self.shell.magic_aimport(mod_name) stream = StringIO() self.shell.magic_aimport("", stream=stream) - nt.assert_true(("Modules to reload:\n%s" % mod_name) in - stream.getvalue()) + nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue()) - nt.assert_raises( - ImportError, - self.shell.magic_aimport, "tmpmod_as318989e89ds") + with nt.assert_raises(ImportError): + self.shell.magic_aimport("tmpmod_as318989e89ds") else: self.shell.magic_autoreload("2") self.shell.run_code("import %s" % mod_name)