From d05065c04115748d7aa9cbe5540970147e8c55d4 2019-03-31 23:13:47 From: Niclas Date: 2019-03-31 23:13:47 Subject: [PATCH] Combined recursive approach with check for already visited objects to avoid infinite recursion. This should pass the already existing autoreload tests, however, new tests for this feature still need to be implemented. --- diff --git a/IPython/extensions/autoreload.py b/IPython/extensions/autoreload.py index 5bd38b3..353124a 100644 --- a/IPython/extensions/autoreload.py +++ b/IPython/extensions/autoreload.py @@ -268,66 +268,14 @@ def update_function(old, new): pass -def _find_instances(old_type): - """Try to find all instances of a class that need updating. - - Classic graph exploration, we want to avoid re-visiting object multiple times. - """ - # find ipython workspace stack frame, this is just to bootstrap where we - # find the object that need updating. - frame = next(frame_nfo.frame for frame_nfo in inspect.stack() - if 'trigger' in frame_nfo.function) - # build generator for non-private variable values from workspace - shell = frame.f_locals['self'].shell - user_ns = shell.user_ns - user_ns_hidden = shell.user_ns_hidden - nonmatching = object() - objects = ( value for key, value in user_ns.items() - if not key.startswith('_') - and (value is not user_ns_hidden.get(key, nonmatching)) - and not inspect.ismodule(value)) - - # note: in the following we do use dict as object might not be hashable. - # list of objects we found that will need an update. - to_update = {} - - # list of object we have not recursed into yet - open_set = {} - - # list of object we have visited already - closed_set = {} - - open_set.update({id(o):o for o in objects}) - - it = 0 - while len(open_set) > 0: - it += 1 - if it > 100_000: - raise ValueError('infinite') - (current_id,current) = next(iter(open_set.items())) - if type(current) is old_type: - to_update[current_id] = current - if hasattr(current, '__dict__') and not (inspect.isfunction(current) - or inspect.ismethod(current)): - potential_new = {id(o):o for o in current.__dict__.values() if id(o) not in closed_set.keys()} - open_set.update(potential_new) - # if object is a container, search it - if hasattr(current, 'items') or (hasattr(current, '__contains__') - and not isinstance(current, str)): - potential_new = (value for key, value in current.items() - if not str(key).startswith('_') - and not inspect.ismodule(value) and not id(value) in closed_set.keys()) - open_set.update(potential_new) - del open_set[id(current)] - closed_set[id(current)] = current - return to_update.values() - -def update_instances(old, new, objects=None): +def update_instances(old, new, objects=None, visited={}): """Iterate through objects recursively, searching for instances of old and replace their __class__ reference with new. If no objects are given, start with the current ipython workspace. """ - if not objects: + if objects is None: + # make sure visited is cleaned when not called recursively + visited = {} # find ipython workspace stack frame frame = next(frame_nfo.frame for frame_nfo in inspect.stack() if 'trigger' in frame_nfo.function) @@ -349,7 +297,9 @@ def update_instances(old, new, objects=None): # try if objects is iterable try: - for obj in objects: + for obj in (obj for obj in objects if id(obj) not in visited): + # add current object to visited to avoid revisiting + visited.update({id(obj):obj}) # update, if object is instance of old_class (but no subclasses) if type(obj) is old: @@ -359,21 +309,21 @@ def update_instances(old, new, objects=None): # if object is instance of other class, look for nested instances if hasattr(obj, '__dict__') and not (inspect.isfunction(obj) or inspect.ismethod(obj)): - update_instances(old, new, obj.__dict__) + update_instances(old, new, obj.__dict__, visited) # if object is a container, search it if hasattr(obj, 'items') or (hasattr(obj, '__contains__') and not isinstance(obj, str)): - update_instances(old, new, obj) + update_instances(old, new, obj, visited) except TypeError: pass - + def update_class(old, new): """Replace stuff in the __dict__ of a class, and upgrade method code objects, and add new methods, if any""" - print('old is', old) + print('old is', id(old)) for key in list(old.__dict__.keys()): old_obj = getattr(old, key) try: @@ -405,8 +355,7 @@ def update_class(old, new): pass # skip non-writable attributes # update all instances of class - for instance in _find_instances(old): - instance.__class__ = new + update_instances(old, new) def update_property(old, new):