def _find_instances(old_type, max_visits=100_000):
    """Find objects reachable from the user namespace whose type is ``old_type``.

    The search is bootstrapped from the call stack: the nearest calling frame
    whose function name contains ``'trigger'`` and whose local ``self`` has a
    ``shell`` attribute supplies ``shell.user_ns`` and
    ``shell.user_ns_hidden``.  Starting from the values of ``user_ns`` the
    object graph is explored breadth-first, visiting each object once.

    Starting points are the ``user_ns`` values whose name does not start with
    an underscore and which are not the very object stored under the same
    name in ``user_ns_hidden``.  From every visited object the walk follows:

    * the values of its ``__dict__`` (skipped for functions and methods);
    * the values of a mapping whose key, as a string, does not start with an
      underscore;
    * the elements of ``list``, ``tuple``, ``set`` and ``frozenset`` objects.

    Modules are never entered, and objects that only use ``__slots__`` are
    visited but their slot values are not followed.

    Parameters
    ----------
    old_type : type
        The class whose instances are wanted.  Matching is exact
        (``type(obj) is old_type``); instances of subclasses are not
        returned.
    max_visits : int, optional
        Upper bound on the number of objects examined.

    Returns
    -------
    list
        The matching objects, each listed once, in discovery order.  Empty
        when no suitable ``trigger`` frame is found on the stack.

    Raises
    ------
    ValueError
        If more than ``max_visits`` objects would have to be examined.
    """
    # Imported locally so the helper does not depend on module-level imports.
    import inspect
    from collections.abc import Mapping

    # Walk up the stack by hand: only the function name and the locals are
    # needed, so building full ``inspect.stack()`` records is wasted work.
    shell = None
    frame = inspect.currentframe()
    try:
        while frame is not None:
            if 'trigger' in frame.f_code.co_name:
                shell = getattr(frame.f_locals.get('self'), 'shell', None)
                if shell is not None:
                    break
            frame = frame.f_back
    finally:
        # Do not keep a frame alive longer than necessary (frames reference
        # their locals, which easily creates reference cycles).
        del frame

    if shell is None:
        # Nothing to bootstrap from: there is no namespace to search.
        return []

    user_ns = shell.user_ns
    user_ns_hidden = shell.user_ns_hidden
    nonmatching = object()

    # Objects are tracked by ``id`` because they may not be hashable.  The
    # dict also keeps every queued object alive, so an id cannot be recycled
    # for a different object while the walk is running.
    seen = {}
    pending = []

    def enqueue(candidate):
        # Modules are never explored; everything else is queued exactly once.
        if inspect.ismodule(candidate) or id(candidate) in seen:
            return
        seen[id(candidate)] = candidate
        pending.append(candidate)

    for name, value in list(user_ns.items()):
        if name.startswith('_'):
            continue
        if value is user_ns_hidden.get(name, nonmatching):
            continue
        enqueue(value)

    found = []
    position = 0
    # ``pending`` grows while it is consumed; ``position`` is the read cursor.
    while position < len(pending):
        if position >= max_visits:
            raise ValueError(
                'gave up searching for instances of %r after visiting %d '
                'objects' % (old_type, max_visits)
            )
        current = pending[position]
        position += 1

        if type(current) is old_type:
            found.append(current)

        # Follow instance/class attributes, but not those of functions or
        # methods.
        if not (inspect.isfunction(current) or inspect.ismethod(current)):
            attributes = getattr(current, '__dict__', None)
            if isinstance(attributes, Mapping):
                for child in list(attributes.values()):
                    enqueue(child)

        # Follow container contents.  Checking the instance (rather than
        # ``hasattr(current, 'items')``) keeps classes such as ``dict`` itself
        # from being mistaken for a mapping.
        if isinstance(current, Mapping):
            for key, child in list(current.items()):
                if not str(key).startswith('_'):
                    enqueue(child)
        elif isinstance(current, (list, tuple, set, frozenset)):
            for child in list(current):
                enqueue(child)

    return found
If no objects are given, start @@ -319,6 +373,7 @@ def update_instances(old, new, objects=None): def update_class(old, new): """Replace stuff in the __dict__ of a class, and upgrade method code objects, and add new methods, if any""" + print('old is', old) for key in list(old.__dict__.keys()): old_obj = getattr(old, key) try: @@ -350,7 +405,8 @@ def update_class(old, new): pass # skip non-writable attributes # update all instances of class - update_instances(old, new) + for instance in _find_instances(old): + instance.__class__ = new def update_property(old, new): diff --git a/IPython/extensions/tests/test_autoreload.py b/IPython/extensions/tests/test_autoreload.py index 74e0125..40d63ff 100644 --- a/IPython/extensions/tests/test_autoreload.py +++ b/IPython/extensions/tests/test_autoreload.py @@ -35,10 +35,12 @@ from IPython.core.events import EventManager, pre_run_cell noop = lambda *a, **kw: None -class FakeShell(object): +class FakeShell: def __init__(self): self.ns = {} + self.user_ns = {} + self.user_ns_hidden = {} self.events = EventManager(self, {'pre_run_cell', pre_run_cell}) self.auto_magics = AutoreloadMagics(shell=self) self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)