##// END OF EJS Templates
Merge pull request #11876 from daharn/autoreload-fix-using-gc.get_referrers()...
Matthias Bussonnier -
r25187:cd54f155 merge
parent child Browse files
Show More
@@ -1,590 +1,550 b''
1 1 """IPython extension to reload modules before executing user code.
2 2
3 3 ``autoreload`` reloads modules automatically before entering the execution of
4 4 code typed at the IPython prompt.
5 5
6 6 This makes for example the following workflow possible:
7 7
8 8 .. sourcecode:: ipython
9 9
10 10 In [1]: %load_ext autoreload
11 11
12 12 In [2]: %autoreload 2
13 13
14 14 In [3]: from foo import some_function
15 15
16 16 In [4]: some_function()
17 17 Out[4]: 42
18 18
19 19 In [5]: # open foo.py in an editor and change some_function to return 43
20 20
21 21 In [6]: some_function()
22 22 Out[6]: 43
23 23
24 24 The module was reloaded without reloading it explicitly, and the object
25 25 imported with ``from foo import ...`` was also updated.
26 26
27 27 Usage
28 28 =====
29 29
30 30 The following magic commands are provided:
31 31
32 32 ``%autoreload``
33 33
34 34 Reload all modules (except those excluded by ``%aimport``)
35 35 automatically now.
36 36
37 37 ``%autoreload 0``
38 38
39 39 Disable automatic reloading.
40 40
41 41 ``%autoreload 1``
42 42
43 43 Reload all modules imported with ``%aimport`` every time before
44 44 executing the Python code typed.
45 45
46 46 ``%autoreload 2``
47 47
48 48 Reload all modules (except those excluded by ``%aimport``) every
49 49 time before executing the Python code typed.
50 50
51 51 ``%aimport``
52 52
53 53 List modules which are to be automatically imported or not to be imported.
54 54
55 55 ``%aimport foo``
56 56
57 57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58 58
59 59 ``%aimport foo, bar``
60 60
61 61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62 62
63 63 ``%aimport -foo``
64 64
65 65 Mark module 'foo' to not be autoreloaded.
66 66
67 67 Caveats
68 68 =======
69 69
70 70 Reloading Python modules in a reliable way is in general difficult,
71 71 and unexpected things may occur. ``%autoreload`` tries to work around
72 72 common pitfalls by replacing function code objects and parts of
73 73 classes previously in the module with new versions. This makes the
74 74 following things to work:
75 75
76 76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 77 to new versions when 'xxx' is reloaded.
78 78
79 79 - Methods and properties of classes are upgraded on reload, so that
80 80 calling 'c.foo()' on an object 'c' created before the reload causes
81 81 the new code for 'foo' to be executed.
82 82
83 83 Some of the known remaining caveats are:
84 84
85 85 - Replacing code objects does not always succeed: changing a @property
86 86 in a class to an ordinary method or a method to a member variable
87 87 can cause problems (but in old objects only).
88 88
89 89 - Functions that are removed (eg. via monkey-patching) from a module
90 90 before it is reloaded are not upgraded.
91 91
92 92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 93 """
94 94
95 95 skip_doctest = True
96 96
97 97 #-----------------------------------------------------------------------------
98 98 # Copyright (C) 2000 Thomas Heller
99 99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 100 # Copyright (C) 2012 The IPython Development Team
101 101 #
102 102 # Distributed under the terms of the BSD License. The full license is in
103 103 # the file COPYING, distributed as part of this software.
104 104 #-----------------------------------------------------------------------------
105 105 #
106 106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 107 # code by Thomas Heller.
108 108
109 109 #-----------------------------------------------------------------------------
110 110 # Imports
111 111 #-----------------------------------------------------------------------------
112 112
113 113 import os
114 114 import sys
115 115 import traceback
116 116 import types
117 117 import weakref
118 import inspect
118 import gc
119 119 from importlib import import_module
120 120 from importlib.util import source_from_cache
121 121 from imp import reload
122 122
123 123 #------------------------------------------------------------------------------
124 124 # Autoreload functionality
125 125 #------------------------------------------------------------------------------
126 126
127 127 class ModuleReloader(object):
128 128 enabled = False
129 129 """Whether this reloader is enabled"""
130 130
131 131 check_all = True
132 132 """Autoreload all modules, not just those listed in 'modules'"""
133 133
134 134 def __init__(self):
135 135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
136 136 self.failed = {}
137 137 # Modules specially marked as autoreloadable.
138 138 self.modules = {}
139 139 # Modules specially marked as not autoreloadable.
140 140 self.skip_modules = {}
141 141 # (module-name, name) -> weakref, for replacing old code objects
142 142 self.old_objects = {}
143 143 # Module modification timestamps
144 144 self.modules_mtimes = {}
145 145
146 146 # Cache module modification times
147 147 self.check(check_all=True, do_reload=False)
148 148
149 149 def mark_module_skipped(self, module_name):
150 150 """Skip reloading the named module in the future"""
151 151 try:
152 152 del self.modules[module_name]
153 153 except KeyError:
154 154 pass
155 155 self.skip_modules[module_name] = True
156 156
157 157 def mark_module_reloadable(self, module_name):
158 158 """Reload the named module in the future (if it is imported)"""
159 159 try:
160 160 del self.skip_modules[module_name]
161 161 except KeyError:
162 162 pass
163 163 self.modules[module_name] = True
164 164
165 165 def aimport_module(self, module_name):
166 166 """Import a module, and mark it reloadable
167 167
168 168 Returns
169 169 -------
170 170 top_module : module
171 171 The imported module if it is top-level, or the top-level
172 172 top_name : module
173 173 Name of top_module
174 174
175 175 """
176 176 self.mark_module_reloadable(module_name)
177 177
178 178 import_module(module_name)
179 179 top_name = module_name.split('.')[0]
180 180 top_module = sys.modules[top_name]
181 181 return top_module, top_name
182 182
183 183 def filename_and_mtime(self, module):
184 184 if not hasattr(module, '__file__') or module.__file__ is None:
185 185 return None, None
186 186
187 187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
188 188 # we cannot reload(__main__) or reload(__mp_main__)
189 189 return None, None
190 190
191 191 filename = module.__file__
192 192 path, ext = os.path.splitext(filename)
193 193
194 194 if ext.lower() == '.py':
195 195 py_filename = filename
196 196 else:
197 197 try:
198 198 py_filename = source_from_cache(filename)
199 199 except ValueError:
200 200 return None, None
201 201
202 202 try:
203 203 pymtime = os.stat(py_filename).st_mtime
204 204 except OSError:
205 205 return None, None
206 206
207 207 return py_filename, pymtime
208 208
209 209 def check(self, check_all=False, do_reload=True):
210 210 """Check whether some modules need to be reloaded."""
211 211
212 212 if not self.enabled and not check_all:
213 213 return
214 214
215 215 if check_all or self.check_all:
216 216 modules = list(sys.modules.keys())
217 217 else:
218 218 modules = list(self.modules.keys())
219 219
220 220 for modname in modules:
221 221 m = sys.modules.get(modname, None)
222 222
223 223 if modname in self.skip_modules:
224 224 continue
225 225
226 226 py_filename, pymtime = self.filename_and_mtime(m)
227 227 if py_filename is None:
228 228 continue
229 229
230 230 try:
231 231 if pymtime <= self.modules_mtimes[modname]:
232 232 continue
233 233 except KeyError:
234 234 self.modules_mtimes[modname] = pymtime
235 235 continue
236 236 else:
237 237 if self.failed.get(py_filename, None) == pymtime:
238 238 continue
239 239
240 240 self.modules_mtimes[modname] = pymtime
241 241
242 242 # If we've reached this point, we should try to reload the module
243 243 if do_reload:
244 244 try:
245 245 superreload(m, reload, self.old_objects)
246 246 if py_filename in self.failed:
247 247 del self.failed[py_filename]
248 248 except:
249 249 print("[autoreload of %s failed: %s]" % (
250 250 modname, traceback.format_exc(10)), file=sys.stderr)
251 251 self.failed[py_filename] = pymtime
252 252
253 253 #------------------------------------------------------------------------------
254 254 # superreload
255 255 #------------------------------------------------------------------------------
256 256
257 257
258 258 func_attrs = ['__code__', '__defaults__', '__doc__',
259 259 '__closure__', '__globals__', '__dict__']
260 260
261 261
262 262 def update_function(old, new):
263 263 """Upgrade the code object of a function"""
264 264 for name in func_attrs:
265 265 try:
266 266 setattr(old, name, getattr(new, name))
267 267 except (AttributeError, TypeError):
268 268 pass
269 269
270 270
271 def update_instances(old, new, objects=None, visited={}):
272 """Iterate through objects recursively, searching for instances of old and
273 replace their __class__ reference with new. If no objects are given, start
274 with the current ipython workspace.
275 """
276 if objects is None:
277 # make sure visited is cleaned when not called recursively
278 visited = {}
279 # find ipython workspace stack frame
280 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
281 if 'trigger' in frame_nfo.function)
282 # build generator for non-private variable values from workspace
283 shell = frame.f_locals['self'].shell
284 user_ns = shell.user_ns
285 user_ns_hidden = shell.user_ns_hidden
286 nonmatching = object()
287 objects = ( value for key, value in user_ns.items()
288 if not key.startswith('_')
289 and (value is not user_ns_hidden.get(key, nonmatching))
290 and not inspect.ismodule(value))
271 def update_instances(old, new):
272 """Use garbage collector to find all instances that refer to the old
273 class definition and update their __class__ to point to the new class
274 definition"""
291 275
292 # use dict values if objects is a dict but don't touch private variables
293 if hasattr(objects, 'items'):
294 objects = (value for key, value in objects.items()
295 if not str(key).startswith('_')
296 and not inspect.ismodule(value) )
276 refs = gc.get_referrers(old)
297 277
298 # try if objects is iterable
299 try:
300 for obj in (obj for obj in objects if id(obj) not in visited):
301 # add current object to visited to avoid revisiting
302 visited.update({id(obj):obj})
303
304 # update, if object is instance of old_class (but no subclasses)
305 if type(obj) is old:
306 obj.__class__ = new
307
308
309 # if object is instance of other class, look for nested instances
310 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
311 or inspect.ismethod(obj)):
312 update_instances(old, new, obj.__dict__, visited)
313
314 # if object is a container, search it
315 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
316 and not isinstance(obj, str)):
317 update_instances(old, new, obj, visited)
318
319 except TypeError:
320 pass
278 for ref in refs:
279 if type(ref) is old:
280 ref.__class__ = new
321 281
322 282
323 283 def update_class(old, new):
324 284 """Replace stuff in the __dict__ of a class, and upgrade
325 285 method code objects, and add new methods, if any"""
326 286 for key in list(old.__dict__.keys()):
327 287 old_obj = getattr(old, key)
328 288 try:
329 289 new_obj = getattr(new, key)
330 290 # explicitly checking that comparison returns True to handle
331 291 # cases where `==` doesn't return a boolean.
332 292 if (old_obj == new_obj) is True:
333 293 continue
334 294 except AttributeError:
335 295 # obsolete attribute: remove it
336 296 try:
337 297 delattr(old, key)
338 298 except (AttributeError, TypeError):
339 299 pass
340 300 continue
341 301
342 302 if update_generic(old_obj, new_obj): continue
343 303
344 304 try:
345 305 setattr(old, key, getattr(new, key))
346 306 except (AttributeError, TypeError):
347 307 pass # skip non-writable attributes
348 308
349 309 for key in list(new.__dict__.keys()):
350 310 if key not in list(old.__dict__.keys()):
351 311 try:
352 312 setattr(old, key, getattr(new, key))
353 313 except (AttributeError, TypeError):
354 314 pass # skip non-writable attributes
355 315
356 316 # update all instances of class
357 317 update_instances(old, new)
358 318
359 319
360 320 def update_property(old, new):
361 321 """Replace get/set/del functions of a property"""
362 322 update_generic(old.fdel, new.fdel)
363 323 update_generic(old.fget, new.fget)
364 324 update_generic(old.fset, new.fset)
365 325
366 326
367 327 def isinstance2(a, b, typ):
368 328 return isinstance(a, typ) and isinstance(b, typ)
369 329
370 330
371 331 UPDATE_RULES = [
372 332 (lambda a, b: isinstance2(a, b, type),
373 333 update_class),
374 334 (lambda a, b: isinstance2(a, b, types.FunctionType),
375 335 update_function),
376 336 (lambda a, b: isinstance2(a, b, property),
377 337 update_property),
378 338 ]
379 339 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
380 340 lambda a, b: update_function(a.__func__, b.__func__)),
381 341 ])
382 342
383 343
384 344 def update_generic(a, b):
385 345 for type_check, update in UPDATE_RULES:
386 346 if type_check(a, b):
387 347 update(a, b)
388 348 return True
389 349 return False
390 350
391 351
392 352 class StrongRef(object):
393 353 def __init__(self, obj):
394 354 self.obj = obj
395 355 def __call__(self):
396 356 return self.obj
397 357
398 358
399 359 def superreload(module, reload=reload, old_objects=None):
400 360 """Enhanced version of the builtin reload function.
401 361
402 362 superreload remembers objects previously in the module, and
403 363
404 364 - upgrades the class dictionary of every old class in the module
405 365 - upgrades the code object of every old function and method
406 366 - clears the module's namespace before reloading
407 367
408 368 """
409 369 if old_objects is None:
410 370 old_objects = {}
411 371
412 372 # collect old objects in the module
413 373 for name, obj in list(module.__dict__.items()):
414 374 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
415 375 continue
416 376 key = (module.__name__, name)
417 377 try:
418 378 old_objects.setdefault(key, []).append(weakref.ref(obj))
419 379 except TypeError:
420 380 pass
421 381
422 382 # reload module
423 383 try:
424 384 # clear namespace first from old cruft
425 385 old_dict = module.__dict__.copy()
426 386 old_name = module.__name__
427 387 module.__dict__.clear()
428 388 module.__dict__['__name__'] = old_name
429 389 module.__dict__['__loader__'] = old_dict['__loader__']
430 390 except (TypeError, AttributeError, KeyError):
431 391 pass
432 392
433 393 try:
434 394 module = reload(module)
435 395 except:
436 396 # restore module dictionary on failed reload
437 397 module.__dict__.update(old_dict)
438 398 raise
439 399
440 400 # iterate over all objects and update functions & classes
441 401 for name, new_obj in list(module.__dict__.items()):
442 402 key = (module.__name__, name)
443 403 if key not in old_objects: continue
444 404
445 405 new_refs = []
446 406 for old_ref in old_objects[key]:
447 407 old_obj = old_ref()
448 408 if old_obj is None: continue
449 409 new_refs.append(old_ref)
450 410 update_generic(old_obj, new_obj)
451 411
452 412 if new_refs:
453 413 old_objects[key] = new_refs
454 414 else:
455 415 del old_objects[key]
456 416
457 417 return module
458 418
459 419 #------------------------------------------------------------------------------
460 420 # IPython connectivity
461 421 #------------------------------------------------------------------------------
462 422
463 423 from IPython.core.magic import Magics, magics_class, line_magic
464 424
465 425 @magics_class
466 426 class AutoreloadMagics(Magics):
467 427 def __init__(self, *a, **kw):
468 428 super(AutoreloadMagics, self).__init__(*a, **kw)
469 429 self._reloader = ModuleReloader()
470 430 self._reloader.check_all = False
471 431 self.loaded_modules = set(sys.modules)
472 432
473 433 @line_magic
474 434 def autoreload(self, parameter_s=''):
475 435 r"""%autoreload => Reload modules automatically
476 436
477 437 %autoreload
478 438 Reload all modules (except those excluded by %aimport) automatically
479 439 now.
480 440
481 441 %autoreload 0
482 442 Disable automatic reloading.
483 443
484 444 %autoreload 1
485 445 Reload all modules imported with %aimport every time before executing
486 446 the Python code typed.
487 447
488 448 %autoreload 2
489 449 Reload all modules (except those excluded by %aimport) every time
490 450 before executing the Python code typed.
491 451
492 452 Reloading Python modules in a reliable way is in general
493 453 difficult, and unexpected things may occur. %autoreload tries to
494 454 work around common pitfalls by replacing function code objects and
495 455 parts of classes previously in the module with new versions. This
496 456 makes the following things to work:
497 457
498 458 - Functions and classes imported via 'from xxx import foo' are upgraded
499 459 to new versions when 'xxx' is reloaded.
500 460
501 461 - Methods and properties of classes are upgraded on reload, so that
502 462 calling 'c.foo()' on an object 'c' created before the reload causes
503 463 the new code for 'foo' to be executed.
504 464
505 465 Some of the known remaining caveats are:
506 466
507 467 - Replacing code objects does not always succeed: changing a @property
508 468 in a class to an ordinary method or a method to a member variable
509 469 can cause problems (but in old objects only).
510 470
511 471 - Functions that are removed (eg. via monkey-patching) from a module
512 472 before it is reloaded are not upgraded.
513 473
514 474 - C extension modules cannot be reloaded, and so cannot be
515 475 autoreloaded.
516 476
517 477 """
518 478 if parameter_s == '':
519 479 self._reloader.check(True)
520 480 elif parameter_s == '0':
521 481 self._reloader.enabled = False
522 482 elif parameter_s == '1':
523 483 self._reloader.check_all = False
524 484 self._reloader.enabled = True
525 485 elif parameter_s == '2':
526 486 self._reloader.check_all = True
527 487 self._reloader.enabled = True
528 488
529 489 @line_magic
530 490 def aimport(self, parameter_s='', stream=None):
531 491 """%aimport => Import modules for automatic reloading.
532 492
533 493 %aimport
534 494 List modules to automatically import and not to import.
535 495
536 496 %aimport foo
537 497 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
538 498
539 499 %aimport foo, bar
540 500 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
541 501
542 502 %aimport -foo
543 503 Mark module 'foo' to not be autoreloaded for %autoreload 1
544 504 """
545 505 modname = parameter_s
546 506 if not modname:
547 507 to_reload = sorted(self._reloader.modules.keys())
548 508 to_skip = sorted(self._reloader.skip_modules.keys())
549 509 if stream is None:
550 510 stream = sys.stdout
551 511 if self._reloader.check_all:
552 512 stream.write("Modules to reload:\nall-except-skipped\n")
553 513 else:
554 514 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
555 515 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
556 516 elif modname.startswith('-'):
557 517 modname = modname[1:]
558 518 self._reloader.mark_module_skipped(modname)
559 519 else:
560 520 for _module in ([_.strip() for _ in modname.split(',')]):
561 521 top_module, top_name = self._reloader.aimport_module(_module)
562 522
563 523 # Inject module to user namespace
564 524 self.shell.push({top_name: top_module})
565 525
566 526 def pre_run_cell(self):
567 527 if self._reloader.enabled:
568 528 try:
569 529 self._reloader.check()
570 530 except:
571 531 pass
572 532
573 533 def post_execute_hook(self):
574 534 """Cache the modification times of any modules imported in this execution
575 535 """
576 536 newly_loaded_modules = set(sys.modules) - self.loaded_modules
577 537 for modname in newly_loaded_modules:
578 538 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
579 539 if pymtime is not None:
580 540 self._reloader.modules_mtimes[modname] = pymtime
581 541
582 542 self.loaded_modules.update(newly_loaded_modules)
583 543
584 544
585 545 def load_ipython_extension(ip):
586 546 """Load the extension in IPython."""
587 547 auto_reload = AutoreloadMagics(ip)
588 548 ip.register_magics(auto_reload)
589 549 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
590 550 ip.events.register('post_execute', auto_reload.post_execute_hook)
General Comments 0
You need to be logged in to leave comments. Login now