diff --git a/IPython/extensions/autoreload.py b/IPython/extensions/autoreload.py index 4edd3b3..da9cc0e 100644 --- a/IPython/extensions/autoreload.py +++ b/IPython/extensions/autoreload.py @@ -115,6 +115,7 @@ import sys import traceback import types import weakref +import inspect from importlib import import_module from importlib.util import source_from_cache from imp import reload @@ -267,6 +268,58 @@ def update_function(old, new): pass +def update_instances(old, new, objects=None, visited={}): + """Iterate through objects recursively, searching for instances of old and + replace their __class__ reference with new. If no objects are given, start + with the current ipython workspace. + """ + if objects is None: + # make sure visited is cleaned when not called recursively + visited = {} + # find ipython workspace stack frame + frame = next(frame_nfo.frame for frame_nfo in inspect.stack() + if 'trigger' in frame_nfo.function) + # build generator for non-private variable values from workspace + shell = frame.f_locals['self'].shell + user_ns = shell.user_ns + user_ns_hidden = shell.user_ns_hidden + nonmatching = object() + objects = ( value for key, value in user_ns.items() + if not key.startswith('_') + and (value is not user_ns_hidden.get(key, nonmatching)) + and not inspect.ismodule(value)) + + # use dict values if objects is a dict but don't touch private variables + if hasattr(objects, 'items'): + objects = (value for key, value in objects.items() + if not str(key).startswith('_') + and not inspect.ismodule(value) ) + + # try if objects is iterable + try: + for obj in (obj for obj in objects if id(obj) not in visited): + # add current object to visited to avoid revisiting + visited.update({id(obj):obj}) + + # update, if object is instance of old_class (but no subclasses) + if type(obj) is old: + obj.__class__ = new + + + # if object is instance of other class, look for nested instances + if hasattr(obj, '__dict__') and not (inspect.isfunction(obj) + or inspect.ismethod(obj)): + update_instances(old, new, obj.__dict__, visited) + + # if object is a container, search it + if hasattr(obj, 'items') or (hasattr(obj, '__contains__') + and not isinstance(obj, str)): + update_instances(old, new, obj, visited) + + except TypeError: + pass + + def update_class(old, new): """Replace stuff in the __dict__ of a class, and upgrade method code objects, and add new methods, if any""" @@ -300,6 +353,9 @@ def update_class(old, new): except (AttributeError, TypeError): pass # skip non-writable attributes + # update all instances of class + update_instances(old, new) + def update_property(old, new): """Replace get/set/del functions of a property""" diff --git a/IPython/extensions/tests/test_autoreload.py b/IPython/extensions/tests/test_autoreload.py index 74e0125..cb046d9 100644 --- a/IPython/extensions/tests/test_autoreload.py +++ b/IPython/extensions/tests/test_autoreload.py @@ -35,10 +35,12 @@ from IPython.core.events import EventManager, pre_run_cell noop = lambda *a, **kw: None -class FakeShell(object): +class FakeShell: def __init__(self): self.ns = {} + self.user_ns = self.ns + self.user_ns_hidden = {} self.events = EventManager(self, {'pre_run_cell', pre_run_cell}) self.auto_magics = AutoreloadMagics(shell=self) self.events.register('pre_run_cell', self.auto_magics.pre_run_cell) @@ -47,7 +49,7 @@ class FakeShell(object): def run_code(self, code): self.events.trigger('pre_run_cell') - exec(code, self.ns) + exec(code, self.user_ns) self.auto_magics.post_execute_hook() def push(self, items): @@ -104,7 +106,7 @@ class Fixture(object): (because that is stored in the file). The only reliable way to achieve this seems to be to sleep. """ - + content = textwrap.dedent(content) # Sleep one second + eps time.sleep(1.05) @@ -113,6 +115,7 @@ class Fixture(object): f.write(content) def new_module(self, code): + code = textwrap.dedent(code) mod_name, mod_fn = self.get_module() with open(mod_fn, 'w') as f: f.write(code) @@ -122,6 +125,17 @@ class Fixture(object): # Test automatic reloading #----------------------------------------------------------------------------- +def pickle_get_current_class(obj): + """ + Original issue comes from pickle; hence the name. + """ + name = obj.__class__.__name__ + module_name = getattr(obj, "__module__", None) + obj2 = sys.modules[module_name] + for subpath in name.split("."): + obj2 = getattr(obj2, subpath) + return obj2 + class TestAutoreload(Fixture): @skipif(sys.version_info < (3, 6)) @@ -145,6 +159,42 @@ class TestAutoreload(Fixture): with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'): self.shell.run_code("pass") # trigger another reload + def test_reload_class_type(self): + self.shell.magic_autoreload("2") + mod_name, mod_fn = self.new_module( + """ + class Test(): + def meth(self): + return "old" + """ + ) + assert "test" not in self.shell.ns + assert "result" not in self.shell.ns + + self.shell.run_code("from %s import Test" % mod_name) + self.shell.run_code("test = Test()") + + self.write_file( + mod_fn, + """ + class Test(): + def meth(self): + return "new" + """, + ) + + test_object = self.shell.ns["test"] + + # important to trigger autoreload logic ! + self.shell.run_code("pass") + + test_class = pickle_get_current_class(test_object) + assert isinstance(test_object, test_class) + + # extra check. + self.shell.run_code("import pickle") + self.shell.run_code("p = pickle.dumps(test)") + def test_reload_class_attributes(self): self.shell.magic_autoreload("2") mod_name, mod_fn = self.new_module(textwrap.dedent(""" @@ -396,3 +446,4 @@ x = -99 +