##// END OF EJS Templates
Merge pull request #13732 from hhoppe/main...
Matthias Bussonnier -
r27730:8eef757f merge
parent child Browse files
Show More
@@ -1,627 +1,627
1 1 """IPython extension to reload modules before executing user code.
2 2
3 3 ``autoreload`` reloads modules automatically before entering the execution of
4 4 code typed at the IPython prompt.
5 5
6 6 This makes for example the following workflow possible:
7 7
8 8 .. sourcecode:: ipython
9 9
10 10 In [1]: %load_ext autoreload
11 11
12 12 In [2]: %autoreload 2
13 13
14 14 In [3]: from foo import some_function
15 15
16 16 In [4]: some_function()
17 17 Out[4]: 42
18 18
19 19 In [5]: # open foo.py in an editor and change some_function to return 43
20 20
21 21 In [6]: some_function()
22 22 Out[6]: 43
23 23
24 24 The module was reloaded without reloading it explicitly, and the object
25 25 imported with ``from foo import ...`` was also updated.
26 26
27 27 Usage
28 28 =====
29 29
30 30 The following magic commands are provided:
31 31
32 32 ``%autoreload``
33 33
34 34 Reload all modules (except those excluded by ``%aimport``)
35 35 automatically now.
36 36
37 37 ``%autoreload 0``
38 38
39 39 Disable automatic reloading.
40 40
41 41 ``%autoreload 1``
42 42
43 43 Reload all modules imported with ``%aimport`` every time before
44 44 executing the Python code typed.
45 45
46 46 ``%autoreload 2``
47 47
48 48 Reload all modules (except those excluded by ``%aimport``) every
49 49 time before executing the Python code typed.
50 50
51 51 ``%autoreload 3``
52 52
53 53 Reload all modules AND autoload newly added objects
54 54 every time before executing the Python code typed.
55 55
56 56 ``%aimport``
57 57
58 58 List modules which are to be automatically imported or not to be imported.
59 59
60 60 ``%aimport foo``
61 61
62 62 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
63 63
64 64 ``%aimport foo, bar``
65 65
66 66 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
67 67
68 68 ``%aimport -foo``
69 69
70 70 Mark module 'foo' to not be autoreloaded.
71 71
72 72 Caveats
73 73 =======
74 74
75 75 Reloading Python modules in a reliable way is in general difficult,
76 76 and unexpected things may occur. ``%autoreload`` tries to work around
77 77 common pitfalls by replacing function code objects and parts of
78 78 classes previously in the module with new versions. This makes the
79 79 following things to work:
80 80
81 81 - Functions and classes imported via 'from xxx import foo' are upgraded
82 82 to new versions when 'xxx' is reloaded.
83 83
84 84 - Methods and properties of classes are upgraded on reload, so that
85 85 calling 'c.foo()' on an object 'c' created before the reload causes
86 86 the new code for 'foo' to be executed.
87 87
88 88 Some of the known remaining caveats are:
89 89
90 90 - Replacing code objects does not always succeed: changing a @property
91 91 in a class to an ordinary method or a method to a member variable
92 92 can cause problems (but in old objects only).
93 93
94 94 - Functions that are removed (eg. via monkey-patching) from a module
95 95 before it is reloaded are not upgraded.
96 96
97 97 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
98 98 """
99 99
100 100 __skip_doctest__ = True
101 101
102 102 # -----------------------------------------------------------------------------
103 103 # Copyright (C) 2000 Thomas Heller
104 104 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
105 105 # Copyright (C) 2012 The IPython Development Team
106 106 #
107 107 # Distributed under the terms of the BSD License. The full license is in
108 108 # the file COPYING, distributed as part of this software.
109 109 # -----------------------------------------------------------------------------
110 110 #
111 111 # This IPython module is written by Pauli Virtanen, based on the autoreload
112 112 # code by Thomas Heller.
113 113
114 114 # -----------------------------------------------------------------------------
115 115 # Imports
116 116 # -----------------------------------------------------------------------------
117 117
118 118 import os
119 119 import sys
120 120 import traceback
121 121 import types
122 122 import weakref
123 123 import gc
124 124 from importlib import import_module, reload
125 125 from importlib.util import source_from_cache
126 126
127 127 # ------------------------------------------------------------------------------
128 128 # Autoreload functionality
129 129 # ------------------------------------------------------------------------------
130 130
131 131
132 132 class ModuleReloader:
133 133 enabled = False
134 134 """Whether this reloader is enabled"""
135 135
136 136 check_all = True
137 137 """Autoreload all modules, not just those listed in 'modules'"""
138 138
139 139 autoload_obj = False
140 140 """Autoreload all modules AND autoload all new objects"""
141 141
142 142 def __init__(self, shell=None):
143 143 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
144 144 self.failed = {}
145 145 # Modules specially marked as autoreloadable.
146 146 self.modules = {}
147 147 # Modules specially marked as not autoreloadable.
148 148 self.skip_modules = {}
149 149 # (module-name, name) -> weakref, for replacing old code objects
150 150 self.old_objects = {}
151 151 # Module modification timestamps
152 152 self.modules_mtimes = {}
153 153 self.shell = shell
154 154
155 155 # Cache module modification times
156 156 self.check(check_all=True, do_reload=False)
157 157
158 158 def mark_module_skipped(self, module_name):
159 159 """Skip reloading the named module in the future"""
160 160 try:
161 161 del self.modules[module_name]
162 162 except KeyError:
163 163 pass
164 164 self.skip_modules[module_name] = True
165 165
166 166 def mark_module_reloadable(self, module_name):
167 167 """Reload the named module in the future (if it is imported)"""
168 168 try:
169 169 del self.skip_modules[module_name]
170 170 except KeyError:
171 171 pass
172 172 self.modules[module_name] = True
173 173
174 174 def aimport_module(self, module_name):
175 175 """Import a module, and mark it reloadable
176 176
177 177 Returns
178 178 -------
179 179 top_module : module
180 180 The imported module if it is top-level, or the top-level
181 181 top_name : module
182 182 Name of top_module
183 183
184 184 """
185 185 self.mark_module_reloadable(module_name)
186 186
187 187 import_module(module_name)
188 188 top_name = module_name.split(".")[0]
189 189 top_module = sys.modules[top_name]
190 190 return top_module, top_name
191 191
192 192 def filename_and_mtime(self, module):
193 193 if not hasattr(module, "__file__") or module.__file__ is None:
194 194 return None, None
195 195
196 196 if getattr(module, "__name__", None) in [None, "__mp_main__", "__main__"]:
197 197 # we cannot reload(__main__) or reload(__mp_main__)
198 198 return None, None
199 199
200 200 filename = module.__file__
201 201 path, ext = os.path.splitext(filename)
202 202
203 203 if ext.lower() == ".py":
204 204 py_filename = filename
205 205 else:
206 206 try:
207 207 py_filename = source_from_cache(filename)
208 208 except ValueError:
209 209 return None, None
210 210
211 211 try:
212 212 pymtime = os.stat(py_filename).st_mtime
213 213 except OSError:
214 214 return None, None
215 215
216 216 return py_filename, pymtime
217 217
218 218 def check(self, check_all=False, do_reload=True):
219 219 """Check whether some modules need to be reloaded."""
220 220
221 221 if not self.enabled and not check_all:
222 222 return
223 223
224 224 if check_all or self.check_all:
225 225 modules = list(sys.modules.keys())
226 226 else:
227 227 modules = list(self.modules.keys())
228 228
229 229 for modname in modules:
230 230 m = sys.modules.get(modname, None)
231 231
232 232 if modname in self.skip_modules:
233 233 continue
234 234
235 235 py_filename, pymtime = self.filename_and_mtime(m)
236 236 if py_filename is None:
237 237 continue
238 238
239 239 try:
240 240 if pymtime <= self.modules_mtimes[modname]:
241 241 continue
242 242 except KeyError:
243 243 self.modules_mtimes[modname] = pymtime
244 244 continue
245 245 else:
246 246 if self.failed.get(py_filename, None) == pymtime:
247 247 continue
248 248
249 249 self.modules_mtimes[modname] = pymtime
250 250
251 251 # If we've reached this point, we should try to reload the module
252 252 if do_reload:
253 253 try:
254 254 if self.autoload_obj:
255 255 superreload(m, reload, self.old_objects, self.shell)
256 256 else:
257 257 superreload(m, reload, self.old_objects)
258 258 if py_filename in self.failed:
259 259 del self.failed[py_filename]
260 260 except:
261 261 print(
262 262 "[autoreload of {} failed: {}]".format(
263 263 modname, traceback.format_exc(10)
264 264 ),
265 265 file=sys.stderr,
266 266 )
267 267 self.failed[py_filename] = pymtime
268 268
269 269
270 270 # ------------------------------------------------------------------------------
271 271 # superreload
272 272 # ------------------------------------------------------------------------------
273 273
274 274
275 275 func_attrs = [
276 276 "__code__",
277 277 "__defaults__",
278 278 "__doc__",
279 279 "__closure__",
280 280 "__globals__",
281 281 "__dict__",
282 282 ]
283 283
284 284
285 285 def update_function(old, new):
286 286 """Upgrade the code object of a function"""
287 287 for name in func_attrs:
288 288 try:
289 289 setattr(old, name, getattr(new, name))
290 290 except (AttributeError, TypeError):
291 291 pass
292 292
293 293
294 294 def update_instances(old, new):
295 295 """Use garbage collector to find all instances that refer to the old
296 296 class definition and update their __class__ to point to the new class
297 297 definition"""
298 298
299 299 refs = gc.get_referrers(old)
300 300
301 301 for ref in refs:
302 302 if type(ref) is old:
303 ref.__class__ = new
303 object.__setattr__(ref, "__class__", new)
304 304
305 305
306 306 def update_class(old, new):
307 307 """Replace stuff in the __dict__ of a class, and upgrade
308 308 method code objects, and add new methods, if any"""
309 309 for key in list(old.__dict__.keys()):
310 310 old_obj = getattr(old, key)
311 311 try:
312 312 new_obj = getattr(new, key)
313 313 # explicitly checking that comparison returns True to handle
314 314 # cases where `==` doesn't return a boolean.
315 315 if (old_obj == new_obj) is True:
316 316 continue
317 317 except AttributeError:
318 318 # obsolete attribute: remove it
319 319 try:
320 320 delattr(old, key)
321 321 except (AttributeError, TypeError):
322 322 pass
323 323 continue
324 324 except ValueError:
325 325 # can't compare nested structures containing
326 326 # numpy arrays using `==`
327 327 pass
328 328
329 329 if update_generic(old_obj, new_obj):
330 330 continue
331 331
332 332 try:
333 333 setattr(old, key, getattr(new, key))
334 334 except (AttributeError, TypeError):
335 335 pass # skip non-writable attributes
336 336
337 337 for key in list(new.__dict__.keys()):
338 338 if key not in list(old.__dict__.keys()):
339 339 try:
340 340 setattr(old, key, getattr(new, key))
341 341 except (AttributeError, TypeError):
342 342 pass # skip non-writable attributes
343 343
344 344 # update all instances of class
345 345 update_instances(old, new)
346 346
347 347
348 348 def update_property(old, new):
349 349 """Replace get/set/del functions of a property"""
350 350 update_generic(old.fdel, new.fdel)
351 351 update_generic(old.fget, new.fget)
352 352 update_generic(old.fset, new.fset)
353 353
354 354
355 355 def isinstance2(a, b, typ):
356 356 return isinstance(a, typ) and isinstance(b, typ)
357 357
358 358
359 359 UPDATE_RULES = [
360 360 (lambda a, b: isinstance2(a, b, type), update_class),
361 361 (lambda a, b: isinstance2(a, b, types.FunctionType), update_function),
362 362 (lambda a, b: isinstance2(a, b, property), update_property),
363 363 ]
364 364 UPDATE_RULES.extend(
365 365 [
366 366 (
367 367 lambda a, b: isinstance2(a, b, types.MethodType),
368 368 lambda a, b: update_function(a.__func__, b.__func__),
369 369 ),
370 370 ]
371 371 )
372 372
373 373
374 374 def update_generic(a, b):
375 375 for type_check, update in UPDATE_RULES:
376 376 if type_check(a, b):
377 377 update(a, b)
378 378 return True
379 379 return False
380 380
381 381
382 382 class StrongRef:
383 383 def __init__(self, obj):
384 384 self.obj = obj
385 385
386 386 def __call__(self):
387 387 return self.obj
388 388
389 389
390 390 mod_attrs = [
391 391 "__name__",
392 392 "__doc__",
393 393 "__package__",
394 394 "__loader__",
395 395 "__spec__",
396 396 "__file__",
397 397 "__cached__",
398 398 "__builtins__",
399 399 ]
400 400
401 401
402 402 def append_obj(module, d, name, obj, autoload=False):
403 403 in_module = hasattr(obj, "__module__") and obj.__module__ == module.__name__
404 404 if autoload:
405 405 # check needed for module global built-ins
406 406 if not in_module and name in mod_attrs:
407 407 return False
408 408 else:
409 409 if not in_module:
410 410 return False
411 411
412 412 key = (module.__name__, name)
413 413 try:
414 414 d.setdefault(key, []).append(weakref.ref(obj))
415 415 except TypeError:
416 416 pass
417 417 return True
418 418
419 419
420 420 def superreload(module, reload=reload, old_objects=None, shell=None):
421 421 """Enhanced version of the builtin reload function.
422 422
423 423 superreload remembers objects previously in the module, and
424 424
425 425 - upgrades the class dictionary of every old class in the module
426 426 - upgrades the code object of every old function and method
427 427 - clears the module's namespace before reloading
428 428
429 429 """
430 430 if old_objects is None:
431 431 old_objects = {}
432 432
433 433 # collect old objects in the module
434 434 for name, obj in list(module.__dict__.items()):
435 435 if not append_obj(module, old_objects, name, obj):
436 436 continue
437 437 key = (module.__name__, name)
438 438 try:
439 439 old_objects.setdefault(key, []).append(weakref.ref(obj))
440 440 except TypeError:
441 441 pass
442 442
443 443 # reload module
444 444 try:
445 445 # clear namespace first from old cruft
446 446 old_dict = module.__dict__.copy()
447 447 old_name = module.__name__
448 448 module.__dict__.clear()
449 449 module.__dict__["__name__"] = old_name
450 450 module.__dict__["__loader__"] = old_dict["__loader__"]
451 451 except (TypeError, AttributeError, KeyError):
452 452 pass
453 453
454 454 try:
455 455 module = reload(module)
456 456 except:
457 457 # restore module dictionary on failed reload
458 458 module.__dict__.update(old_dict)
459 459 raise
460 460
461 461 # iterate over all objects and update functions & classes
462 462 for name, new_obj in list(module.__dict__.items()):
463 463 key = (module.__name__, name)
464 464 if key not in old_objects:
465 465 # here 'shell' acts both as a flag and as an output var
466 466 if (
467 467 shell is None
468 468 or name == "Enum"
469 469 or not append_obj(module, old_objects, name, new_obj, True)
470 470 ):
471 471 continue
472 472 shell.user_ns[name] = new_obj
473 473
474 474 new_refs = []
475 475 for old_ref in old_objects[key]:
476 476 old_obj = old_ref()
477 477 if old_obj is None:
478 478 continue
479 479 new_refs.append(old_ref)
480 480 update_generic(old_obj, new_obj)
481 481
482 482 if new_refs:
483 483 old_objects[key] = new_refs
484 484 else:
485 485 del old_objects[key]
486 486
487 487 return module
488 488
489 489
490 490 # ------------------------------------------------------------------------------
491 491 # IPython connectivity
492 492 # ------------------------------------------------------------------------------
493 493
494 494 from IPython.core.magic import Magics, magics_class, line_magic
495 495
496 496
497 497 @magics_class
498 498 class AutoreloadMagics(Magics):
499 499 def __init__(self, *a, **kw):
500 500 super().__init__(*a, **kw)
501 501 self._reloader = ModuleReloader(self.shell)
502 502 self._reloader.check_all = False
503 503 self._reloader.autoload_obj = False
504 504 self.loaded_modules = set(sys.modules)
505 505
506 506 @line_magic
507 507 def autoreload(self, parameter_s=""):
508 508 r"""%autoreload => Reload modules automatically
509 509
510 510 %autoreload
511 511 Reload all modules (except those excluded by %aimport) automatically
512 512 now.
513 513
514 514 %autoreload 0
515 515 Disable automatic reloading.
516 516
517 517 %autoreload 1
518 518 Reload all modules imported with %aimport every time before executing
519 519 the Python code typed.
520 520
521 521 %autoreload 2
522 522 Reload all modules (except those excluded by %aimport) every time
523 523 before executing the Python code typed.
524 524
525 525 Reloading Python modules in a reliable way is in general
526 526 difficult, and unexpected things may occur. %autoreload tries to
527 527 work around common pitfalls by replacing function code objects and
528 528 parts of classes previously in the module with new versions. This
529 529 makes the following things to work:
530 530
531 531 - Functions and classes imported via 'from xxx import foo' are upgraded
532 532 to new versions when 'xxx' is reloaded.
533 533
534 534 - Methods and properties of classes are upgraded on reload, so that
535 535 calling 'c.foo()' on an object 'c' created before the reload causes
536 536 the new code for 'foo' to be executed.
537 537
538 538 Some of the known remaining caveats are:
539 539
540 540 - Replacing code objects does not always succeed: changing a @property
541 541 in a class to an ordinary method or a method to a member variable
542 542 can cause problems (but in old objects only).
543 543
544 544 - Functions that are removed (eg. via monkey-patching) from a module
545 545 before it is reloaded are not upgraded.
546 546
547 547 - C extension modules cannot be reloaded, and so cannot be
548 548 autoreloaded.
549 549
550 550 """
551 551 if parameter_s == "":
552 552 self._reloader.check(True)
553 553 elif parameter_s == "0":
554 554 self._reloader.enabled = False
555 555 elif parameter_s == "1":
556 556 self._reloader.check_all = False
557 557 self._reloader.enabled = True
558 558 elif parameter_s == "2":
559 559 self._reloader.check_all = True
560 560 self._reloader.enabled = True
561 561 self._reloader.enabled = True
562 562 elif parameter_s == "3":
563 563 self._reloader.check_all = True
564 564 self._reloader.enabled = True
565 565 self._reloader.autoload_obj = True
566 566
567 567 @line_magic
568 568 def aimport(self, parameter_s="", stream=None):
569 569 """%aimport => Import modules for automatic reloading.
570 570
571 571 %aimport
572 572 List modules to automatically import and not to import.
573 573
574 574 %aimport foo
575 575 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
576 576
577 577 %aimport foo, bar
578 578 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
579 579
580 580 %aimport -foo
581 581 Mark module 'foo' to not be autoreloaded for %autoreload 1
582 582 """
583 583 modname = parameter_s
584 584 if not modname:
585 585 to_reload = sorted(self._reloader.modules.keys())
586 586 to_skip = sorted(self._reloader.skip_modules.keys())
587 587 if stream is None:
588 588 stream = sys.stdout
589 589 if self._reloader.check_all:
590 590 stream.write("Modules to reload:\nall-except-skipped\n")
591 591 else:
592 592 stream.write("Modules to reload:\n%s\n" % " ".join(to_reload))
593 593 stream.write("\nModules to skip:\n%s\n" % " ".join(to_skip))
594 594 elif modname.startswith("-"):
595 595 modname = modname[1:]
596 596 self._reloader.mark_module_skipped(modname)
597 597 else:
598 598 for _module in [_.strip() for _ in modname.split(",")]:
599 599 top_module, top_name = self._reloader.aimport_module(_module)
600 600
601 601 # Inject module to user namespace
602 602 self.shell.push({top_name: top_module})
603 603
604 604 def pre_run_cell(self):
605 605 if self._reloader.enabled:
606 606 try:
607 607 self._reloader.check()
608 608 except:
609 609 pass
610 610
611 611 def post_execute_hook(self):
612 612 """Cache the modification times of any modules imported in this execution"""
613 613 newly_loaded_modules = set(sys.modules) - self.loaded_modules
614 614 for modname in newly_loaded_modules:
615 615 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
616 616 if pymtime is not None:
617 617 self._reloader.modules_mtimes[modname] = pymtime
618 618
619 619 self.loaded_modules.update(newly_loaded_modules)
620 620
621 621
622 622 def load_ipython_extension(ip):
623 623 """Load the extension in IPython."""
624 624 auto_reload = AutoreloadMagics(ip)
625 625 ip.register_magics(auto_reload)
626 626 ip.events.register("pre_run_cell", auto_reload.pre_run_cell)
627 627 ip.events.register("post_execute", auto_reload.post_execute_hook)
General Comments 0
You need to be logged in to leave comments. Login now