##// END OF EJS Templates
Call `lower` only once
Emilio Graff -
Show More
@@ -1,680 +1,681 b''
1 1 """IPython extension to reload modules before executing user code.
2 2
3 3 ``autoreload`` reloads modules automatically before entering the execution of
4 4 code typed at the IPython prompt.
5 5
6 6 This makes for example the following workflow possible:
7 7
8 8 .. sourcecode:: ipython
9 9
10 10 In [1]: %load_ext autoreload
11 11
12 12 In [2]: %autoreload 2
13 13
14 14 In [3]: from foo import some_function
15 15
16 16 In [4]: some_function()
17 17 Out[4]: 42
18 18
19 19 In [5]: # open foo.py in an editor and change some_function to return 43
20 20
21 21 In [6]: some_function()
22 22 Out[6]: 43
23 23
24 24 The module was reloaded without reloading it explicitly, and the object
25 25 imported with ``from foo import ...`` was also updated.
26 26
27 27 Usage
28 28 =====
29 29
30 30 The following magic commands are provided:
31 31
32 32 ``%autoreload``, ``%autoreload now``
33 33
34 34 Reload all modules (except those excluded by ``%aimport``)
35 35 automatically now.
36 36
37 37 ``%autoreload 0``, ``%autoreload off``
38 38
39 39 Disable automatic reloading.
40 40
41 41 ``%autoreload 1``, ``%autoreload explicit``
42 42
43 43 Reload all modules imported with ``%aimport`` every time before
44 44 executing the Python code typed.
45 45
46 46 ``%autoreload 2``, ``%autoreload all``
47 47
48 48 Reload all modules (except those excluded by ``%aimport``) every
49 49 time before executing the Python code typed.
50 50
51 51 ``%autoreload 3``, ``%autoreload complete``
52 52
53 53 Same as 2/all, but also adds any new objects in the module. See
54 54 unit test at IPython/extensions/tests/test_autoreload.py::test_autoload_newly_added_objects
55 55
56 56 ``%aimport``
57 57
58 58 List modules which are to be automatically imported or not to be imported.
59 59
60 60 ``%aimport foo``
61 61
62 62 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
63 63
64 64 ``%aimport foo, bar``
65 65
66 66 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
67 67
68 68 ``%aimport -foo``
69 69
70 70 Mark module 'foo' to not be autoreloaded.
71 71
72 72 ``%averbose off``
73 73
74 74 Perform autoreload tasks quietly
75 75
76 76 ``%averbose on``
77 77
78 78 Report activity with `print` statements.
79 79
80 80 ``%averbose log``
81 81
82 82 Report activity with the logger.
83 83
84 84 Caveats
85 85 =======
86 86
87 87 Reloading Python modules in a reliable way is in general difficult,
88 88 and unexpected things may occur. ``%autoreload`` tries to work around
89 89 common pitfalls by replacing function code objects and parts of
90 90 classes previously in the module with new versions. This makes the
91 91 following things to work:
92 92
93 93 - Functions and classes imported via 'from xxx import foo' are upgraded
94 94 to new versions when 'xxx' is reloaded.
95 95
96 96 - Methods and properties of classes are upgraded on reload, so that
97 97 calling 'c.foo()' on an object 'c' created before the reload causes
98 98 the new code for 'foo' to be executed.
99 99
100 100 Some of the known remaining caveats are:
101 101
102 102 - Replacing code objects does not always succeed: changing a @property
103 103 in a class to an ordinary method or a method to a member variable
104 104 can cause problems (but in old objects only).
105 105
106 106 - Functions that are removed (eg. via monkey-patching) from a module
107 107 before it is reloaded are not upgraded.
108 108
109 109 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
110 110
111 111 - While comparing Enum and Flag, the 'is' Identity Operator is used (even in the case '==' has been used (Similar to the 'None' keyword)).
112 112
113 113 - Reloading a module, or importing the same module by a different name, creates new Enums. These may look the same, but are not.
114 114 """
115 115
116 116 from IPython.core.magic import Magics, magics_class, line_magic
117 117
118 118 __skip_doctest__ = True
119 119
120 120 # -----------------------------------------------------------------------------
121 121 # Copyright (C) 2000 Thomas Heller
122 122 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
123 123 # Copyright (C) 2012 The IPython Development Team
124 124 #
125 125 # Distributed under the terms of the BSD License. The full license is in
126 126 # the file COPYING, distributed as part of this software.
127 127 # -----------------------------------------------------------------------------
128 128 #
129 129 # This IPython module is written by Pauli Virtanen, based on the autoreload
130 130 # code by Thomas Heller.
131 131
132 132 # -----------------------------------------------------------------------------
133 133 # Imports
134 134 # -----------------------------------------------------------------------------
135 135
136 136 import os
137 137 import sys
138 138 import traceback
139 139 import types
140 140 import weakref
141 141 import gc
142 142 import logging
143 143 from importlib import import_module, reload
144 144 from importlib.util import source_from_cache
145 145
146 146 # ------------------------------------------------------------------------------
147 147 # Autoreload functionality
148 148 # ------------------------------------------------------------------------------
149 149
150 150
151 151 class ModuleReloader:
152 152 enabled = False
153 153 """Whether this reloader is enabled"""
154 154
155 155 check_all = True
156 156 """Autoreload all modules, not just those listed in 'modules'"""
157 157
158 158 autoload_obj = False
159 159 """Autoreload all modules AND autoload all new objects"""
160 160
161 161 def __init__(self, shell=None):
162 162 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
163 163 self.failed = {}
164 164 # Modules specially marked as autoreloadable.
165 165 self.modules = {}
166 166 # Modules specially marked as not autoreloadable.
167 167 self.skip_modules = {}
168 168 # (module-name, name) -> weakref, for replacing old code objects
169 169 self.old_objects = {}
170 170 # Module modification timestamps
171 171 self.modules_mtimes = {}
172 172 self.shell = shell
173 173
174 174 # Reporting callable for verbosity
175 175 self._report = lambda msg: None # by default, be quiet.
176 176
177 177 # Cache module modification times
178 178 self.check(check_all=True, do_reload=False)
179 179
180 180 def mark_module_skipped(self, module_name):
181 181 """Skip reloading the named module in the future"""
182 182 try:
183 183 del self.modules[module_name]
184 184 except KeyError:
185 185 pass
186 186 self.skip_modules[module_name] = True
187 187
188 188 def mark_module_reloadable(self, module_name):
189 189 """Reload the named module in the future (if it is imported)"""
190 190 try:
191 191 del self.skip_modules[module_name]
192 192 except KeyError:
193 193 pass
194 194 self.modules[module_name] = True
195 195
196 196 def aimport_module(self, module_name):
197 197 """Import a module, and mark it reloadable
198 198
199 199 Returns
200 200 -------
201 201 top_module : module
202 202 The imported module if it is top-level, or the top-level
203 203 top_name : module
204 204 Name of top_module
205 205
206 206 """
207 207 self.mark_module_reloadable(module_name)
208 208
209 209 import_module(module_name)
210 210 top_name = module_name.split(".")[0]
211 211 top_module = sys.modules[top_name]
212 212 return top_module, top_name
213 213
214 214 def filename_and_mtime(self, module):
215 215 if not hasattr(module, "__file__") or module.__file__ is None:
216 216 return None, None
217 217
218 218 if getattr(module, "__name__", None) in [None, "__mp_main__", "__main__"]:
219 219 # we cannot reload(__main__) or reload(__mp_main__)
220 220 return None, None
221 221
222 222 filename = module.__file__
223 223 path, ext = os.path.splitext(filename)
224 224
225 225 if ext.lower() == ".py":
226 226 py_filename = filename
227 227 else:
228 228 try:
229 229 py_filename = source_from_cache(filename)
230 230 except ValueError:
231 231 return None, None
232 232
233 233 try:
234 234 pymtime = os.stat(py_filename).st_mtime
235 235 except OSError:
236 236 return None, None
237 237
238 238 return py_filename, pymtime
239 239
240 240 def check(self, check_all=False, do_reload=True):
241 241 """Check whether some modules need to be reloaded."""
242 242
243 243 if not self.enabled and not check_all:
244 244 return
245 245
246 246 if check_all or self.check_all:
247 247 modules = list(sys.modules.keys())
248 248 else:
249 249 modules = list(self.modules.keys())
250 250
251 251 for modname in modules:
252 252 m = sys.modules.get(modname, None)
253 253
254 254 if modname in self.skip_modules:
255 255 continue
256 256
257 257 py_filename, pymtime = self.filename_and_mtime(m)
258 258 if py_filename is None:
259 259 continue
260 260
261 261 try:
262 262 if pymtime <= self.modules_mtimes[modname]:
263 263 continue
264 264 except KeyError:
265 265 self.modules_mtimes[modname] = pymtime
266 266 continue
267 267 else:
268 268 if self.failed.get(py_filename, None) == pymtime:
269 269 continue
270 270
271 271 self.modules_mtimes[modname] = pymtime
272 272
273 273 # If we've reached this point, we should try to reload the module
274 274 if do_reload:
275 275 self._report(f"Reloading '{modname}'.")
276 276 try:
277 277 if self.autoload_obj:
278 278 superreload(m, reload, self.old_objects, self.shell)
279 279 else:
280 280 superreload(m, reload, self.old_objects)
281 281 if py_filename in self.failed:
282 282 del self.failed[py_filename]
283 283 except:
284 284 print(
285 285 "[autoreload of {} failed: {}]".format(
286 286 modname, traceback.format_exc(10)
287 287 ),
288 288 file=sys.stderr,
289 289 )
290 290 self.failed[py_filename] = pymtime
291 291
292 292
293 293 # ------------------------------------------------------------------------------
294 294 # superreload
295 295 # ------------------------------------------------------------------------------
296 296
297 297
298 298 func_attrs = [
299 299 "__code__",
300 300 "__defaults__",
301 301 "__doc__",
302 302 "__closure__",
303 303 "__globals__",
304 304 "__dict__",
305 305 ]
306 306
307 307
308 308 def update_function(old, new):
309 309 """Upgrade the code object of a function"""
310 310 for name in func_attrs:
311 311 try:
312 312 setattr(old, name, getattr(new, name))
313 313 except (AttributeError, TypeError):
314 314 pass
315 315
316 316
317 317 def update_instances(old, new):
318 318 """Use garbage collector to find all instances that refer to the old
319 319 class definition and update their __class__ to point to the new class
320 320 definition"""
321 321
322 322 refs = gc.get_referrers(old)
323 323
324 324 for ref in refs:
325 325 if type(ref) is old:
326 326 object.__setattr__(ref, "__class__", new)
327 327
328 328
329 329 def update_class(old, new):
330 330 """Replace stuff in the __dict__ of a class, and upgrade
331 331 method code objects, and add new methods, if any"""
332 332 for key in list(old.__dict__.keys()):
333 333 old_obj = getattr(old, key)
334 334 try:
335 335 new_obj = getattr(new, key)
336 336 # explicitly checking that comparison returns True to handle
337 337 # cases where `==` doesn't return a boolean.
338 338 if (old_obj == new_obj) is True:
339 339 continue
340 340 except AttributeError:
341 341 # obsolete attribute: remove it
342 342 try:
343 343 delattr(old, key)
344 344 except (AttributeError, TypeError):
345 345 pass
346 346 continue
347 347 except ValueError:
348 348 # can't compare nested structures containing
349 349 # numpy arrays using `==`
350 350 pass
351 351
352 352 if update_generic(old_obj, new_obj):
353 353 continue
354 354
355 355 try:
356 356 setattr(old, key, getattr(new, key))
357 357 except (AttributeError, TypeError):
358 358 pass # skip non-writable attributes
359 359
360 360 for key in list(new.__dict__.keys()):
361 361 if key not in list(old.__dict__.keys()):
362 362 try:
363 363 setattr(old, key, getattr(new, key))
364 364 except (AttributeError, TypeError):
365 365 pass # skip non-writable attributes
366 366
367 367 # update all instances of class
368 368 update_instances(old, new)
369 369
370 370
371 371 def update_property(old, new):
372 372 """Replace get/set/del functions of a property"""
373 373 update_generic(old.fdel, new.fdel)
374 374 update_generic(old.fget, new.fget)
375 375 update_generic(old.fset, new.fset)
376 376
377 377
378 378 def isinstance2(a, b, typ):
379 379 return isinstance(a, typ) and isinstance(b, typ)
380 380
381 381
382 382 UPDATE_RULES = [
383 383 (lambda a, b: isinstance2(a, b, type), update_class),
384 384 (lambda a, b: isinstance2(a, b, types.FunctionType), update_function),
385 385 (lambda a, b: isinstance2(a, b, property), update_property),
386 386 ]
387 387 UPDATE_RULES.extend(
388 388 [
389 389 (
390 390 lambda a, b: isinstance2(a, b, types.MethodType),
391 391 lambda a, b: update_function(a.__func__, b.__func__),
392 392 ),
393 393 ]
394 394 )
395 395
396 396
397 397 def update_generic(a, b):
398 398 for type_check, update in UPDATE_RULES:
399 399 if type_check(a, b):
400 400 update(a, b)
401 401 return True
402 402 return False
403 403
404 404
405 405 class StrongRef:
406 406 def __init__(self, obj):
407 407 self.obj = obj
408 408
409 409 def __call__(self):
410 410 return self.obj
411 411
412 412
413 413 mod_attrs = [
414 414 "__name__",
415 415 "__doc__",
416 416 "__package__",
417 417 "__loader__",
418 418 "__spec__",
419 419 "__file__",
420 420 "__cached__",
421 421 "__builtins__",
422 422 ]
423 423
424 424
425 425 def append_obj(module, d, name, obj, autoload=False):
426 426 in_module = hasattr(obj, "__module__") and obj.__module__ == module.__name__
427 427 if autoload:
428 428 # check needed for module global built-ins
429 429 if not in_module and name in mod_attrs:
430 430 return False
431 431 else:
432 432 if not in_module:
433 433 return False
434 434
435 435 key = (module.__name__, name)
436 436 try:
437 437 d.setdefault(key, []).append(weakref.ref(obj))
438 438 except TypeError:
439 439 pass
440 440 return True
441 441
442 442
443 443 def superreload(module, reload=reload, old_objects=None, shell=None):
444 444 """Enhanced version of the builtin reload function.
445 445
446 446 superreload remembers objects previously in the module, and
447 447
448 448 - upgrades the class dictionary of every old class in the module
449 449 - upgrades the code object of every old function and method
450 450 - clears the module's namespace before reloading
451 451
452 452 """
453 453 if old_objects is None:
454 454 old_objects = {}
455 455
456 456 # collect old objects in the module
457 457 for name, obj in list(module.__dict__.items()):
458 458 if not append_obj(module, old_objects, name, obj):
459 459 continue
460 460 key = (module.__name__, name)
461 461 try:
462 462 old_objects.setdefault(key, []).append(weakref.ref(obj))
463 463 except TypeError:
464 464 pass
465 465
466 466 # reload module
467 467 try:
468 468 # clear namespace first from old cruft
469 469 old_dict = module.__dict__.copy()
470 470 old_name = module.__name__
471 471 module.__dict__.clear()
472 472 module.__dict__["__name__"] = old_name
473 473 module.__dict__["__loader__"] = old_dict["__loader__"]
474 474 except (TypeError, AttributeError, KeyError):
475 475 pass
476 476
477 477 try:
478 478 module = reload(module)
479 479 except:
480 480 # restore module dictionary on failed reload
481 481 module.__dict__.update(old_dict)
482 482 raise
483 483
484 484 # iterate over all objects and update functions & classes
485 485 for name, new_obj in list(module.__dict__.items()):
486 486 key = (module.__name__, name)
487 487 if key not in old_objects:
488 488 # here 'shell' acts both as a flag and as an output var
489 489 if (
490 490 shell is None
491 491 or name == "Enum"
492 492 or not append_obj(module, old_objects, name, new_obj, True)
493 493 ):
494 494 continue
495 495 shell.user_ns[name] = new_obj
496 496
497 497 new_refs = []
498 498 for old_ref in old_objects[key]:
499 499 old_obj = old_ref()
500 500 if old_obj is None:
501 501 continue
502 502 new_refs.append(old_ref)
503 503 update_generic(old_obj, new_obj)
504 504
505 505 if new_refs:
506 506 old_objects[key] = new_refs
507 507 else:
508 508 del old_objects[key]
509 509
510 510 return module
511 511
512 512
513 513 # ------------------------------------------------------------------------------
514 514 # IPython connectivity
515 515 # ------------------------------------------------------------------------------
516 516
517 517
518 518 @magics_class
519 519 class AutoreloadMagics(Magics):
520 520 def __init__(self, *a, **kw):
521 521 super().__init__(*a, **kw)
522 522 self._reloader = ModuleReloader(self.shell)
523 523 self._reloader.check_all = False
524 524 self._reloader.autoload_obj = False
525 525 self.loaded_modules = set(sys.modules)
526 526
527 527 @line_magic
528 528 def autoreload(self, parameter_s=""):
529 529 r"""%autoreload => Reload modules automatically
530 530
531 531 %autoreload or %autoreload now
532 532 Reload all modules (except those excluded by %aimport) automatically
533 533 now.
534 534
535 535 %autoreload 0 or %autoreload off
536 536 Disable automatic reloading.
537 537
538 538 %autoreload 1 or %autoreload explicit
539 539 Reload only modules imported with %aimport every time before executing
540 540 the Python code typed.
541 541
542 542 %autoreload 2 or %autoreload all
543 543 Reload all modules (except those excluded by %aimport) every time
544 544 before executing the Python code typed.
545 545
546 546 %autoreload 3 or %autoreload complete
547 547 Same as 2/all, but also but also adds any new objects in the module. See
548 548 unit test at IPython/extensions/tests/test_autoreload.py::test_autoload_newly_added_objects
549 549
550 550 Reloading Python modules in a reliable way is in general
551 551 difficult, and unexpected things may occur. %autoreload tries to
552 552 work around common pitfalls by replacing function code objects and
553 553 parts of classes previously in the module with new versions. This
554 554 makes the following things to work:
555 555
556 556 - Functions and classes imported via 'from xxx import foo' are upgraded
557 557 to new versions when 'xxx' is reloaded.
558 558
559 559 - Methods and properties of classes are upgraded on reload, so that
560 560 calling 'c.foo()' on an object 'c' created before the reload causes
561 561 the new code for 'foo' to be executed.
562 562
563 563 Some of the known remaining caveats are:
564 564
565 565 - Replacing code objects does not always succeed: changing a @property
566 566 in a class to an ordinary method or a method to a member variable
567 567 can cause problems (but in old objects only).
568 568
569 569 - Functions that are removed (eg. via monkey-patching) from a module
570 570 before it is reloaded are not upgraded.
571 571
572 572 - C extension modules cannot be reloaded, and so cannot be
573 573 autoreloaded.
574 574
575 575 """
576 if parameter_s == "" or parameter_s.lower() == "now":
576 parameter_s_lower = parameter_s.lower()
577 if parameter_s == "" or parameter_s_lower == "now":
577 578 self._reloader.check(True)
578 elif parameter_s == "0" or parameter_s.lower() == "off":
579 elif parameter_s == "0" or parameter_s_lower == "off":
579 580 self._reloader.enabled = False
580 elif parameter_s == "1" or parameter_s.lower() == "explicit":
581 elif parameter_s == "1" or parameter_s_lower == "explicit":
581 582 self._reloader.check_all = False
582 583 self._reloader.enabled = True
583 elif parameter_s == "2" or parameter_s.lower() == "all":
584 elif parameter_s == "2" or parameter_s_lower == "all":
584 585 self._reloader.check_all = True
585 586 self._reloader.enabled = True
586 elif parameter_s == "3" or parameter_s.lower() == "complete":
587 elif parameter_s == "3" or parameter_s_lower == "complete":
587 588 self._reloader.check_all = True
588 589 self._reloader.enabled = True
589 590 self._reloader.autoload_obj = True
590 591 else:
591 592 raise ValueError(f'Unrecognized parameter "{parameter_s}".')
592 593
593 594 @line_magic
594 595 def aimport(self, parameter_s="", stream=None):
595 596 """%aimport => Import modules for automatic reloading.
596 597
597 598 %aimport
598 599 List modules to automatically import and not to import.
599 600
600 601 %aimport foo
601 602 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
602 603
603 604 %aimport foo, bar
604 605 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
605 606
606 607 %aimport -foo, bar
607 608 Mark module 'foo' to not be autoreloaded for %autoreload 1, 2, or 3, and 'bar'
608 609 to be autoreloaded for 1.
609 610 """
610 611 modname = parameter_s
611 612 if not modname:
612 613 to_reload = sorted(self._reloader.modules.keys())
613 614 to_skip = sorted(self._reloader.skip_modules.keys())
614 615 if stream is None:
615 616 stream = sys.stdout
616 617 if self._reloader.check_all:
617 618 stream.write("Modules to reload:\nall-except-skipped\n")
618 619 else:
619 620 stream.write("Modules to reload:\n%s\n" % " ".join(to_reload))
620 621 stream.write("\nModules to skip:\n%s\n" % " ".join(to_skip))
621 622 else:
622 623 for _module in [_.strip() for _ in modname.split(",")]:
623 624 if _module.startswith("-"):
624 625 _module = _module[1:].strip()
625 626 self._reloader.mark_module_skipped(_module)
626 627 else:
627 628 top_module, top_name = self._reloader.aimport_module(_module)
628 629
629 630 # Inject module to user namespace
630 631 self.shell.push({top_name: top_module})
631 632
632 633 @line_magic
633 634 def averbose(self, parameter_s=""):
634 635 r"""%averbose => Turn verbosity on/off for autoreloading.
635 636
636 637 %averbose 0 or %averbose off
637 638 Turn off any reporting during autoreload.
638 639
639 640 %averbose 1 or %averbose on
640 641 Report autoreload activity via print statements.
641 642
642 643 %averbose 2 or %averbose log
643 644 Report autoreload activity via logging.
644 645 """
645 646
646 647 if parameter_s == "0" or parameter_s.lower() == "off":
647 648 self._reloader._report = lambda msg: None
648 649 elif parameter_s == "1" or parameter_s.lower() == "on":
649 650 self._reloader._report = lambda msg: print(msg)
650 651 elif parameter_s == "2" or parameter_s.lower() == "log":
651 652 self._reloader._report = lambda msg: logging.getLogger("autoreload").info(
652 653 msg
653 654 )
654 655 else:
655 656 raise ValueError(f'Unrecognized parameter "{parameter_s}".')
656 657
657 658 def pre_run_cell(self):
658 659 if self._reloader.enabled:
659 660 try:
660 661 self._reloader.check()
661 662 except:
662 663 pass
663 664
664 665 def post_execute_hook(self):
665 666 """Cache the modification times of any modules imported in this execution"""
666 667 newly_loaded_modules = set(sys.modules) - self.loaded_modules
667 668 for modname in newly_loaded_modules:
668 669 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
669 670 if pymtime is not None:
670 671 self._reloader.modules_mtimes[modname] = pymtime
671 672
672 673 self.loaded_modules.update(newly_loaded_modules)
673 674
674 675
675 676 def load_ipython_extension(ip):
676 677 """Load the extension in IPython."""
677 678 auto_reload = AutoreloadMagics(ip)
678 679 ip.register_magics(auto_reload)
679 680 ip.events.register("pre_run_cell", auto_reload.pre_run_cell)
680 681 ip.events.register("post_execute", auto_reload.post_execute_hook)
General Comments 0
You need to be logged in to leave comments. Login now