##// END OF EJS Templates
Merge pull request #11644 from daharn/master...
Matthias Bussonnier -
r25092:bd5753b3 merge
parent child Browse files
Show More
@@ -115,6 +115,7 b' import sys'
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 import inspect
118 from importlib import import_module
119 from importlib import import_module
119 from importlib.util import source_from_cache
120 from importlib.util import source_from_cache
120 from imp import reload
121 from imp import reload
@@ -267,6 +268,58 b' def update_function(old, new):'
267 pass
268 pass
268
269
269
270
271 def update_instances(old, new, objects=None, visited={}):
272 """Iterate through objects recursively, searching for instances of old and
273 replace their __class__ reference with new. If no objects are given, start
274 with the current ipython workspace.
275 """
276 if objects is None:
277 # make sure visited is cleaned when not called recursively
278 visited = {}
279 # find ipython workspace stack frame
280 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
281 if 'trigger' in frame_nfo.function)
282 # build generator for non-private variable values from workspace
283 shell = frame.f_locals['self'].shell
284 user_ns = shell.user_ns
285 user_ns_hidden = shell.user_ns_hidden
286 nonmatching = object()
287 objects = ( value for key, value in user_ns.items()
288 if not key.startswith('_')
289 and (value is not user_ns_hidden.get(key, nonmatching))
290 and not inspect.ismodule(value))
291
292 # use dict values if objects is a dict but don't touch private variables
293 if hasattr(objects, 'items'):
294 objects = (value for key, value in objects.items()
295 if not str(key).startswith('_')
296 and not inspect.ismodule(value) )
297
298 # try if objects is iterable
299 try:
300 for obj in (obj for obj in objects if id(obj) not in visited):
301 # add current object to visited to avoid revisiting
302 visited.update({id(obj):obj})
303
304 # update, if object is instance of old_class (but no subclasses)
305 if type(obj) is old:
306 obj.__class__ = new
307
308
309 # if object is instance of other class, look for nested instances
310 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
311 or inspect.ismethod(obj)):
312 update_instances(old, new, obj.__dict__, visited)
313
314 # if object is a container, search it
315 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
316 and not isinstance(obj, str)):
317 update_instances(old, new, obj, visited)
318
319 except TypeError:
320 pass
321
322
270 def update_class(old, new):
323 def update_class(old, new):
271 """Replace stuff in the __dict__ of a class, and upgrade
324 """Replace stuff in the __dict__ of a class, and upgrade
272 method code objects, and add new methods, if any"""
325 method code objects, and add new methods, if any"""
@@ -300,6 +353,9 b' def update_class(old, new):'
300 except (AttributeError, TypeError):
353 except (AttributeError, TypeError):
301 pass # skip non-writable attributes
354 pass # skip non-writable attributes
302
355
356 # update all instances of class
357 update_instances(old, new)
358
303
359
304 def update_property(old, new):
360 def update_property(old, new):
305 """Replace get/set/del functions of a property"""
361 """Replace get/set/del functions of a property"""
@@ -35,10 +35,12 b' from IPython.core.events import EventManager, pre_run_cell'
35
35
36 noop = lambda *a, **kw: None
36 noop = lambda *a, **kw: None
37
37
38 class FakeShell(object):
38 class FakeShell:
39
39
40 def __init__(self):
40 def __init__(self):
41 self.ns = {}
41 self.ns = {}
42 self.user_ns = self.ns
43 self.user_ns_hidden = {}
42 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
44 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
43 self.auto_magics = AutoreloadMagics(shell=self)
45 self.auto_magics = AutoreloadMagics(shell=self)
44 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
46 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
@@ -47,7 +49,7 b' class FakeShell(object):'
47
49
48 def run_code(self, code):
50 def run_code(self, code):
49 self.events.trigger('pre_run_cell')
51 self.events.trigger('pre_run_cell')
50 exec(code, self.ns)
52 exec(code, self.user_ns)
51 self.auto_magics.post_execute_hook()
53 self.auto_magics.post_execute_hook()
52
54
53 def push(self, items):
55 def push(self, items):
@@ -104,7 +106,7 b' class Fixture(object):'
104 (because that is stored in the file). The only reliable way
106 (because that is stored in the file). The only reliable way
105 to achieve this seems to be to sleep.
107 to achieve this seems to be to sleep.
106 """
108 """
107
109 content = textwrap.dedent(content)
108 # Sleep one second + eps
110 # Sleep one second + eps
109 time.sleep(1.05)
111 time.sleep(1.05)
110
112
@@ -113,6 +115,7 b' class Fixture(object):'
113 f.write(content)
115 f.write(content)
114
116
115 def new_module(self, code):
117 def new_module(self, code):
118 code = textwrap.dedent(code)
116 mod_name, mod_fn = self.get_module()
119 mod_name, mod_fn = self.get_module()
117 with open(mod_fn, 'w') as f:
120 with open(mod_fn, 'w') as f:
118 f.write(code)
121 f.write(code)
@@ -122,6 +125,17 b' class Fixture(object):'
122 # Test automatic reloading
125 # Test automatic reloading
123 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
124
127
128 def pickle_get_current_class(obj):
129 """
130 Original issue comes from pickle; hence the name.
131 """
132 name = obj.__class__.__name__
133 module_name = getattr(obj, "__module__", None)
134 obj2 = sys.modules[module_name]
135 for subpath in name.split("."):
136 obj2 = getattr(obj2, subpath)
137 return obj2
138
125 class TestAutoreload(Fixture):
139 class TestAutoreload(Fixture):
126
140
127 @skipif(sys.version_info < (3, 6))
141 @skipif(sys.version_info < (3, 6))
@@ -145,6 +159,42 b' class TestAutoreload(Fixture):'
145 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
159 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
146 self.shell.run_code("pass") # trigger another reload
160 self.shell.run_code("pass") # trigger another reload
147
161
162 def test_reload_class_type(self):
163 self.shell.magic_autoreload("2")
164 mod_name, mod_fn = self.new_module(
165 """
166 class Test():
167 def meth(self):
168 return "old"
169 """
170 )
171 assert "test" not in self.shell.ns
172 assert "result" not in self.shell.ns
173
174 self.shell.run_code("from %s import Test" % mod_name)
175 self.shell.run_code("test = Test()")
176
177 self.write_file(
178 mod_fn,
179 """
180 class Test():
181 def meth(self):
182 return "new"
183 """,
184 )
185
186 test_object = self.shell.ns["test"]
187
188 # important to trigger autoreload logic !
189 self.shell.run_code("pass")
190
191 test_class = pickle_get_current_class(test_object)
192 assert isinstance(test_object, test_class)
193
194 # extra check.
195 self.shell.run_code("import pickle")
196 self.shell.run_code("p = pickle.dumps(test)")
197
148 def test_reload_class_attributes(self):
198 def test_reload_class_attributes(self):
149 self.shell.magic_autoreload("2")
199 self.shell.magic_autoreload("2")
150 mod_name, mod_fn = self.new_module(textwrap.dedent("""
200 mod_name, mod_fn = self.new_module(textwrap.dedent("""
@@ -396,3 +446,4 b' x = -99'
396
446
397
447
398
448
449
General Comments 0
You need to be logged in to leave comments. Login now