Show More
@@ -115,6 +115,7 b' import sys' | |||||
115 | import traceback |
|
115 | import traceback | |
116 | import types |
|
116 | import types | |
117 | import weakref |
|
117 | import weakref | |
|
118 | import inspect | |||
118 | from importlib import import_module |
|
119 | from importlib import import_module | |
119 | from importlib.util import source_from_cache |
|
120 | from importlib.util import source_from_cache | |
120 | from imp import reload |
|
121 | from imp import reload | |
@@ -267,6 +268,58 b' def update_function(old, new):' | |||||
267 | pass |
|
268 | pass | |
268 |
|
269 | |||
269 |
|
270 | |||
|
271 | def update_instances(old, new, objects=None, visited={}): | |||
|
272 | """Iterate through objects recursively, searching for instances of old and | |||
|
273 | replace their __class__ reference with new. If no objects are given, start | |||
|
274 | with the current ipython workspace. | |||
|
275 | """ | |||
|
276 | if objects is None: | |||
|
277 | # make sure visited is cleaned when not called recursively | |||
|
278 | visited = {} | |||
|
279 | # find ipython workspace stack frame | |||
|
280 | frame = next(frame_nfo.frame for frame_nfo in inspect.stack() | |||
|
281 | if 'trigger' in frame_nfo.function) | |||
|
282 | # build generator for non-private variable values from workspace | |||
|
283 | shell = frame.f_locals['self'].shell | |||
|
284 | user_ns = shell.user_ns | |||
|
285 | user_ns_hidden = shell.user_ns_hidden | |||
|
286 | nonmatching = object() | |||
|
287 | objects = ( value for key, value in user_ns.items() | |||
|
288 | if not key.startswith('_') | |||
|
289 | and (value is not user_ns_hidden.get(key, nonmatching)) | |||
|
290 | and not inspect.ismodule(value)) | |||
|
291 | ||||
|
292 | # use dict values if objects is a dict but don't touch private variables | |||
|
293 | if hasattr(objects, 'items'): | |||
|
294 | objects = (value for key, value in objects.items() | |||
|
295 | if not str(key).startswith('_') | |||
|
296 | and not inspect.ismodule(value) ) | |||
|
297 | ||||
|
298 | # try if objects is iterable | |||
|
299 | try: | |||
|
300 | for obj in (obj for obj in objects if id(obj) not in visited): | |||
|
301 | # add current object to visited to avoid revisiting | |||
|
302 | visited.update({id(obj):obj}) | |||
|
303 | ||||
|
304 | # update, if object is instance of old_class (but no subclasses) | |||
|
305 | if type(obj) is old: | |||
|
306 | obj.__class__ = new | |||
|
307 | ||||
|
308 | ||||
|
309 | # if object is instance of other class, look for nested instances | |||
|
310 | if hasattr(obj, '__dict__') and not (inspect.isfunction(obj) | |||
|
311 | or inspect.ismethod(obj)): | |||
|
312 | update_instances(old, new, obj.__dict__, visited) | |||
|
313 | ||||
|
314 | # if object is a container, search it | |||
|
315 | if hasattr(obj, 'items') or (hasattr(obj, '__contains__') | |||
|
316 | and not isinstance(obj, str)): | |||
|
317 | update_instances(old, new, obj, visited) | |||
|
318 | ||||
|
319 | except TypeError: | |||
|
320 | pass | |||
|
321 | ||||
|
322 | ||||
270 | def update_class(old, new): |
|
323 | def update_class(old, new): | |
271 | """Replace stuff in the __dict__ of a class, and upgrade |
|
324 | """Replace stuff in the __dict__ of a class, and upgrade | |
272 | method code objects, and add new methods, if any""" |
|
325 | method code objects, and add new methods, if any""" | |
@@ -300,6 +353,9 b' def update_class(old, new):' | |||||
300 | except (AttributeError, TypeError): |
|
353 | except (AttributeError, TypeError): | |
301 | pass # skip non-writable attributes |
|
354 | pass # skip non-writable attributes | |
302 |
|
355 | |||
|
356 | # update all instances of class | |||
|
357 | update_instances(old, new) | |||
|
358 | ||||
303 |
|
359 | |||
304 | def update_property(old, new): |
|
360 | def update_property(old, new): | |
305 | """Replace get/set/del functions of a property""" |
|
361 | """Replace get/set/del functions of a property""" |
@@ -35,10 +35,12 b' from IPython.core.events import EventManager, pre_run_cell' | |||||
35 |
|
35 | |||
36 | noop = lambda *a, **kw: None |
|
36 | noop = lambda *a, **kw: None | |
37 |
|
37 | |||
38 |
class FakeShell |
|
38 | class FakeShell: | |
39 |
|
39 | |||
40 | def __init__(self): |
|
40 | def __init__(self): | |
41 | self.ns = {} |
|
41 | self.ns = {} | |
|
42 | self.user_ns = self.ns | |||
|
43 | self.user_ns_hidden = {} | |||
42 | self.events = EventManager(self, {'pre_run_cell', pre_run_cell}) |
|
44 | self.events = EventManager(self, {'pre_run_cell', pre_run_cell}) | |
43 | self.auto_magics = AutoreloadMagics(shell=self) |
|
45 | self.auto_magics = AutoreloadMagics(shell=self) | |
44 | self.events.register('pre_run_cell', self.auto_magics.pre_run_cell) |
|
46 | self.events.register('pre_run_cell', self.auto_magics.pre_run_cell) | |
@@ -47,7 +49,7 b' class FakeShell(object):' | |||||
47 |
|
49 | |||
48 | def run_code(self, code): |
|
50 | def run_code(self, code): | |
49 | self.events.trigger('pre_run_cell') |
|
51 | self.events.trigger('pre_run_cell') | |
50 | exec(code, self.ns) |
|
52 | exec(code, self.user_ns) | |
51 | self.auto_magics.post_execute_hook() |
|
53 | self.auto_magics.post_execute_hook() | |
52 |
|
54 | |||
53 | def push(self, items): |
|
55 | def push(self, items): | |
@@ -104,7 +106,7 b' class Fixture(object):' | |||||
104 | (because that is stored in the file). The only reliable way |
|
106 | (because that is stored in the file). The only reliable way | |
105 | to achieve this seems to be to sleep. |
|
107 | to achieve this seems to be to sleep. | |
106 | """ |
|
108 | """ | |
107 |
|
109 | content = textwrap.dedent(content) | ||
108 | # Sleep one second + eps |
|
110 | # Sleep one second + eps | |
109 | time.sleep(1.05) |
|
111 | time.sleep(1.05) | |
110 |
|
112 | |||
@@ -113,6 +115,7 b' class Fixture(object):' | |||||
113 | f.write(content) |
|
115 | f.write(content) | |
114 |
|
116 | |||
115 | def new_module(self, code): |
|
117 | def new_module(self, code): | |
|
118 | code = textwrap.dedent(code) | |||
116 | mod_name, mod_fn = self.get_module() |
|
119 | mod_name, mod_fn = self.get_module() | |
117 | with open(mod_fn, 'w') as f: |
|
120 | with open(mod_fn, 'w') as f: | |
118 | f.write(code) |
|
121 | f.write(code) | |
@@ -122,6 +125,17 b' class Fixture(object):' | |||||
122 | # Test automatic reloading |
|
125 | # Test automatic reloading | |
123 | #----------------------------------------------------------------------------- |
|
126 | #----------------------------------------------------------------------------- | |
124 |
|
127 | |||
|
128 | def pickle_get_current_class(obj): | |||
|
129 | """ | |||
|
130 | Original issue comes from pickle; hence the name. | |||
|
131 | """ | |||
|
132 | name = obj.__class__.__name__ | |||
|
133 | module_name = getattr(obj, "__module__", None) | |||
|
134 | obj2 = sys.modules[module_name] | |||
|
135 | for subpath in name.split("."): | |||
|
136 | obj2 = getattr(obj2, subpath) | |||
|
137 | return obj2 | |||
|
138 | ||||
125 | class TestAutoreload(Fixture): |
|
139 | class TestAutoreload(Fixture): | |
126 |
|
140 | |||
127 | @skipif(sys.version_info < (3, 6)) |
|
141 | @skipif(sys.version_info < (3, 6)) | |
@@ -145,6 +159,42 b' class TestAutoreload(Fixture):' | |||||
145 | with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'): |
|
159 | with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'): | |
146 | self.shell.run_code("pass") # trigger another reload |
|
160 | self.shell.run_code("pass") # trigger another reload | |
147 |
|
161 | |||
|
162 | def test_reload_class_type(self): | |||
|
163 | self.shell.magic_autoreload("2") | |||
|
164 | mod_name, mod_fn = self.new_module( | |||
|
165 | """ | |||
|
166 | class Test(): | |||
|
167 | def meth(self): | |||
|
168 | return "old" | |||
|
169 | """ | |||
|
170 | ) | |||
|
171 | assert "test" not in self.shell.ns | |||
|
172 | assert "result" not in self.shell.ns | |||
|
173 | ||||
|
174 | self.shell.run_code("from %s import Test" % mod_name) | |||
|
175 | self.shell.run_code("test = Test()") | |||
|
176 | ||||
|
177 | self.write_file( | |||
|
178 | mod_fn, | |||
|
179 | """ | |||
|
180 | class Test(): | |||
|
181 | def meth(self): | |||
|
182 | return "new" | |||
|
183 | """, | |||
|
184 | ) | |||
|
185 | ||||
|
186 | test_object = self.shell.ns["test"] | |||
|
187 | ||||
|
188 | # important to trigger autoreload logic ! | |||
|
189 | self.shell.run_code("pass") | |||
|
190 | ||||
|
191 | test_class = pickle_get_current_class(test_object) | |||
|
192 | assert isinstance(test_object, test_class) | |||
|
193 | ||||
|
194 | # extra check. | |||
|
195 | self.shell.run_code("import pickle") | |||
|
196 | self.shell.run_code("p = pickle.dumps(test)") | |||
|
197 | ||||
148 | def test_reload_class_attributes(self): |
|
198 | def test_reload_class_attributes(self): | |
149 | self.shell.magic_autoreload("2") |
|
199 | self.shell.magic_autoreload("2") | |
150 | mod_name, mod_fn = self.new_module(textwrap.dedent(""" |
|
200 | mod_name, mod_fn = self.new_module(textwrap.dedent(""" | |
@@ -396,3 +446,4 b' x = -99' | |||||
396 |
|
446 | |||
397 |
|
447 | |||
398 |
|
448 | |||
|
449 |
General Comments 0
You need to be logged in to leave comments.
Login now