##// END OF EJS Templates
Combined recursive approach with check for already visited objects to avoid infinite recursion....
Niclas -
Show More
@@ -1,642 +1,591 b''
1 1 """IPython extension to reload modules before executing user code.
2 2
3 3 ``autoreload`` reloads modules automatically before entering the execution of
4 4 code typed at the IPython prompt.
5 5
6 6 This makes for example the following workflow possible:
7 7
8 8 .. sourcecode:: ipython
9 9
10 10 In [1]: %load_ext autoreload
11 11
12 12 In [2]: %autoreload 2
13 13
14 14 In [3]: from foo import some_function
15 15
16 16 In [4]: some_function()
17 17 Out[4]: 42
18 18
19 19 In [5]: # open foo.py in an editor and change some_function to return 43
20 20
21 21 In [6]: some_function()
22 22 Out[6]: 43
23 23
24 24 The module was reloaded without reloading it explicitly, and the object
25 25 imported with ``from foo import ...`` was also updated.
26 26
27 27 Usage
28 28 =====
29 29
30 30 The following magic commands are provided:
31 31
32 32 ``%autoreload``
33 33
34 34 Reload all modules (except those excluded by ``%aimport``)
35 35 automatically now.
36 36
37 37 ``%autoreload 0``
38 38
39 39 Disable automatic reloading.
40 40
41 41 ``%autoreload 1``
42 42
43 43 Reload all modules imported with ``%aimport`` every time before
44 44 executing the Python code typed.
45 45
46 46 ``%autoreload 2``
47 47
48 48 Reload all modules (except those excluded by ``%aimport``) every
49 49 time before executing the Python code typed.
50 50
51 51 ``%aimport``
52 52
53 53 List modules which are to be automatically imported or not to be imported.
54 54
55 55 ``%aimport foo``
56 56
57 57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58 58
59 59 ``%aimport foo, bar``
60 60
61 61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62 62
63 63 ``%aimport -foo``
64 64
65 65 Mark module 'foo' to not be autoreloaded.
66 66
67 67 Caveats
68 68 =======
69 69
70 70 Reloading Python modules in a reliable way is in general difficult,
71 71 and unexpected things may occur. ``%autoreload`` tries to work around
72 72 common pitfalls by replacing function code objects and parts of
73 73 classes previously in the module with new versions. This makes the
74 74 following things to work:
75 75
76 76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 77 to new versions when 'xxx' is reloaded.
78 78
79 79 - Methods and properties of classes are upgraded on reload, so that
80 80 calling 'c.foo()' on an object 'c' created before the reload causes
81 81 the new code for 'foo' to be executed.
82 82
83 83 Some of the known remaining caveats are:
84 84
85 85 - Replacing code objects does not always succeed: changing a @property
86 86 in a class to an ordinary method or a method to a member variable
87 87 can cause problems (but in old objects only).
88 88
89 89 - Functions that are removed (eg. via monkey-patching) from a module
90 90 before it is reloaded are not upgraded.
91 91
92 92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 93 """
94 94
95 95 skip_doctest = True
96 96
97 97 #-----------------------------------------------------------------------------
98 98 # Copyright (C) 2000 Thomas Heller
99 99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 100 # Copyright (C) 2012 The IPython Development Team
101 101 #
102 102 # Distributed under the terms of the BSD License. The full license is in
103 103 # the file COPYING, distributed as part of this software.
104 104 #-----------------------------------------------------------------------------
105 105 #
106 106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 107 # code by Thomas Heller.
108 108
109 109 #-----------------------------------------------------------------------------
110 110 # Imports
111 111 #-----------------------------------------------------------------------------
112 112
113 113 import os
114 114 import sys
115 115 import traceback
116 116 import types
117 117 import weakref
118 118 import inspect
119 119 from importlib import import_module
120 120 from importlib.util import source_from_cache
121 121 from imp import reload
122 122
123 123 #------------------------------------------------------------------------------
124 124 # Autoreload functionality
125 125 #------------------------------------------------------------------------------
126 126
127 127 class ModuleReloader(object):
128 128 enabled = False
129 129 """Whether this reloader is enabled"""
130 130
131 131 check_all = True
132 132 """Autoreload all modules, not just those listed in 'modules'"""
133 133
134 134 def __init__(self):
135 135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
136 136 self.failed = {}
137 137 # Modules specially marked as autoreloadable.
138 138 self.modules = {}
139 139 # Modules specially marked as not autoreloadable.
140 140 self.skip_modules = {}
141 141 # (module-name, name) -> weakref, for replacing old code objects
142 142 self.old_objects = {}
143 143 # Module modification timestamps
144 144 self.modules_mtimes = {}
145 145
146 146 # Cache module modification times
147 147 self.check(check_all=True, do_reload=False)
148 148
149 149 def mark_module_skipped(self, module_name):
150 150 """Skip reloading the named module in the future"""
151 151 try:
152 152 del self.modules[module_name]
153 153 except KeyError:
154 154 pass
155 155 self.skip_modules[module_name] = True
156 156
157 157 def mark_module_reloadable(self, module_name):
158 158 """Reload the named module in the future (if it is imported)"""
159 159 try:
160 160 del self.skip_modules[module_name]
161 161 except KeyError:
162 162 pass
163 163 self.modules[module_name] = True
164 164
165 165 def aimport_module(self, module_name):
166 166 """Import a module, and mark it reloadable
167 167
168 168 Returns
169 169 -------
170 170 top_module : module
171 171 The imported module if it is top-level, or the top-level
172 172 top_name : module
173 173 Name of top_module
174 174
175 175 """
176 176 self.mark_module_reloadable(module_name)
177 177
178 178 import_module(module_name)
179 179 top_name = module_name.split('.')[0]
180 180 top_module = sys.modules[top_name]
181 181 return top_module, top_name
182 182
183 183 def filename_and_mtime(self, module):
184 184 if not hasattr(module, '__file__') or module.__file__ is None:
185 185 return None, None
186 186
187 187 if getattr(module, '__name__', None) in [None, '__mp_main__', '__main__']:
188 188 # we cannot reload(__main__) or reload(__mp_main__)
189 189 return None, None
190 190
191 191 filename = module.__file__
192 192 path, ext = os.path.splitext(filename)
193 193
194 194 if ext.lower() == '.py':
195 195 py_filename = filename
196 196 else:
197 197 try:
198 198 py_filename = source_from_cache(filename)
199 199 except ValueError:
200 200 return None, None
201 201
202 202 try:
203 203 pymtime = os.stat(py_filename).st_mtime
204 204 except OSError:
205 205 return None, None
206 206
207 207 return py_filename, pymtime
208 208
209 209 def check(self, check_all=False, do_reload=True):
210 210 """Check whether some modules need to be reloaded."""
211 211
212 212 if not self.enabled and not check_all:
213 213 return
214 214
215 215 if check_all or self.check_all:
216 216 modules = list(sys.modules.keys())
217 217 else:
218 218 modules = list(self.modules.keys())
219 219
220 220 for modname in modules:
221 221 m = sys.modules.get(modname, None)
222 222
223 223 if modname in self.skip_modules:
224 224 continue
225 225
226 226 py_filename, pymtime = self.filename_and_mtime(m)
227 227 if py_filename is None:
228 228 continue
229 229
230 230 try:
231 231 if pymtime <= self.modules_mtimes[modname]:
232 232 continue
233 233 except KeyError:
234 234 self.modules_mtimes[modname] = pymtime
235 235 continue
236 236 else:
237 237 if self.failed.get(py_filename, None) == pymtime:
238 238 continue
239 239
240 240 self.modules_mtimes[modname] = pymtime
241 241
242 242 # If we've reached this point, we should try to reload the module
243 243 if do_reload:
244 244 try:
245 245 superreload(m, reload, self.old_objects)
246 246 if py_filename in self.failed:
247 247 del self.failed[py_filename]
248 248 except:
249 249 print("[autoreload of %s failed: %s]" % (
250 250 modname, traceback.format_exc(10)), file=sys.stderr)
251 251 self.failed[py_filename] = pymtime
252 252
253 253 #------------------------------------------------------------------------------
254 254 # superreload
255 255 #------------------------------------------------------------------------------
256 256
257 257
258 258 func_attrs = ['__code__', '__defaults__', '__doc__',
259 259 '__closure__', '__globals__', '__dict__']
260 260
261 261
262 262 def update_function(old, new):
263 263 """Upgrade the code object of a function"""
264 264 for name in func_attrs:
265 265 try:
266 266 setattr(old, name, getattr(new, name))
267 267 except (AttributeError, TypeError):
268 268 pass
269 269
270 270
271 def _find_instances(old_type):
272 """Try to find all instances of a class that need updating.
273
274 Classic graph exploration, we want to avoid re-visiting object multiple times.
275 """
276 # find ipython workspace stack frame, this is just to bootstrap where we
277 # find the object that need updating.
278 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
279 if 'trigger' in frame_nfo.function)
280 # build generator for non-private variable values from workspace
281 shell = frame.f_locals['self'].shell
282 user_ns = shell.user_ns
283 user_ns_hidden = shell.user_ns_hidden
284 nonmatching = object()
285 objects = ( value for key, value in user_ns.items()
286 if not key.startswith('_')
287 and (value is not user_ns_hidden.get(key, nonmatching))
288 and not inspect.ismodule(value))
289
290 # note: in the following we do use dict as object might not be hashable.
291 # list of objects we found that will need an update.
292 to_update = {}
293
294 # list of object we have not recursed into yet
295 open_set = {}
296
297 # list of object we have visited already
298 closed_set = {}
299
300 open_set.update({id(o):o for o in objects})
301
302 it = 0
303 while len(open_set) > 0:
304 it += 1
305 if it > 100_000:
306 raise ValueError('infinite')
307 (current_id,current) = next(iter(open_set.items()))
308 if type(current) is old_type:
309 to_update[current_id] = current
310 if hasattr(current, '__dict__') and not (inspect.isfunction(current)
311 or inspect.ismethod(current)):
312 potential_new = {id(o):o for o in current.__dict__.values() if id(o) not in closed_set.keys()}
313 open_set.update(potential_new)
314 # if object is a container, search it
315 if hasattr(current, 'items') or (hasattr(current, '__contains__')
316 and not isinstance(current, str)):
317 potential_new = (value for key, value in current.items()
318 if not str(key).startswith('_')
319 and not inspect.ismodule(value) and not id(value) in closed_set.keys())
320 open_set.update(potential_new)
321 del open_set[id(current)]
322 closed_set[id(current)] = current
323 return to_update.values()
324
325 def update_instances(old, new, objects=None):
271 def update_instances(old, new, objects=None, visited={}):
326 272 """Iterate through objects recursively, searching for instances of old and
327 273 replace their __class__ reference with new. If no objects are given, start
328 274 with the current ipython workspace.
329 275 """
330 if not objects:
276 if objects is None:
277 # make sure visited is cleaned when not called recursively
278 visited = {}
331 279 # find ipython workspace stack frame
332 280 frame = next(frame_nfo.frame for frame_nfo in inspect.stack()
333 281 if 'trigger' in frame_nfo.function)
334 282 # build generator for non-private variable values from workspace
335 283 shell = frame.f_locals['self'].shell
336 284 user_ns = shell.user_ns
337 285 user_ns_hidden = shell.user_ns_hidden
338 286 nonmatching = object()
339 287 objects = ( value for key, value in user_ns.items()
340 288 if not key.startswith('_')
341 289 and (value is not user_ns_hidden.get(key, nonmatching))
342 290 and not inspect.ismodule(value))
343 291
344 292 # use dict values if objects is a dict but don't touch private variables
345 293 if hasattr(objects, 'items'):
346 294 objects = (value for key, value in objects.items()
347 295 if not str(key).startswith('_')
348 296 and not inspect.ismodule(value) )
349 297
350 298 # try if objects is iterable
351 299 try:
352 for obj in objects:
300 for obj in (obj for obj in objects if id(obj) not in visited):
301 # add current object to visited to avoid revisiting
302 visited.update({id(obj):obj})
353 303
354 304 # update, if object is instance of old_class (but no subclasses)
355 305 if type(obj) is old:
356 306 obj.__class__ = new
357 307
358 308
359 309 # if object is instance of other class, look for nested instances
360 310 if hasattr(obj, '__dict__') and not (inspect.isfunction(obj)
361 311 or inspect.ismethod(obj)):
362 update_instances(old, new, obj.__dict__)
312 update_instances(old, new, obj.__dict__, visited)
363 313
364 314 # if object is a container, search it
365 315 if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
366 316 and not isinstance(obj, str)):
367 update_instances(old, new, obj)
317 update_instances(old, new, obj, visited)
368 318
369 319 except TypeError:
370 320 pass
371 321
372 322
373 323 def update_class(old, new):
374 324 """Replace stuff in the __dict__ of a class, and upgrade
375 325 method code objects, and add new methods, if any"""
376 print('old is', old)
326 print('old is', id(old))
377 327 for key in list(old.__dict__.keys()):
378 328 old_obj = getattr(old, key)
379 329 try:
380 330 new_obj = getattr(new, key)
381 331 # explicitly checking that comparison returns True to handle
382 332 # cases where `==` doesn't return a boolean.
383 333 if (old_obj == new_obj) is True:
384 334 continue
385 335 except AttributeError:
386 336 # obsolete attribute: remove it
387 337 try:
388 338 delattr(old, key)
389 339 except (AttributeError, TypeError):
390 340 pass
391 341 continue
392 342
393 343 if update_generic(old_obj, new_obj): continue
394 344
395 345 try:
396 346 setattr(old, key, getattr(new, key))
397 347 except (AttributeError, TypeError):
398 348 pass # skip non-writable attributes
399 349
400 350 for key in list(new.__dict__.keys()):
401 351 if key not in list(old.__dict__.keys()):
402 352 try:
403 353 setattr(old, key, getattr(new, key))
404 354 except (AttributeError, TypeError):
405 355 pass # skip non-writable attributes
406 356
407 357 # update all instances of class
408 for instance in _find_instances(old):
409 instance.__class__ = new
358 update_instances(old, new)
410 359
411 360
412 361 def update_property(old, new):
413 362 """Replace get/set/del functions of a property"""
414 363 update_generic(old.fdel, new.fdel)
415 364 update_generic(old.fget, new.fget)
416 365 update_generic(old.fset, new.fset)
417 366
418 367
419 368 def isinstance2(a, b, typ):
420 369 return isinstance(a, typ) and isinstance(b, typ)
421 370
422 371
423 372 UPDATE_RULES = [
424 373 (lambda a, b: isinstance2(a, b, type),
425 374 update_class),
426 375 (lambda a, b: isinstance2(a, b, types.FunctionType),
427 376 update_function),
428 377 (lambda a, b: isinstance2(a, b, property),
429 378 update_property),
430 379 ]
431 380 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
432 381 lambda a, b: update_function(a.__func__, b.__func__)),
433 382 ])
434 383
435 384
436 385 def update_generic(a, b):
437 386 for type_check, update in UPDATE_RULES:
438 387 if type_check(a, b):
439 388 update(a, b)
440 389 return True
441 390 return False
442 391
443 392
444 393 class StrongRef(object):
445 394 def __init__(self, obj):
446 395 self.obj = obj
447 396 def __call__(self):
448 397 return self.obj
449 398
450 399
451 400 def superreload(module, reload=reload, old_objects=None):
452 401 """Enhanced version of the builtin reload function.
453 402
454 403 superreload remembers objects previously in the module, and
455 404
456 405 - upgrades the class dictionary of every old class in the module
457 406 - upgrades the code object of every old function and method
458 407 - clears the module's namespace before reloading
459 408
460 409 """
461 410 if old_objects is None:
462 411 old_objects = {}
463 412
464 413 # collect old objects in the module
465 414 for name, obj in list(module.__dict__.items()):
466 415 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
467 416 continue
468 417 key = (module.__name__, name)
469 418 try:
470 419 old_objects.setdefault(key, []).append(weakref.ref(obj))
471 420 except TypeError:
472 421 pass
473 422
474 423 # reload module
475 424 try:
476 425 # clear namespace first from old cruft
477 426 old_dict = module.__dict__.copy()
478 427 old_name = module.__name__
479 428 module.__dict__.clear()
480 429 module.__dict__['__name__'] = old_name
481 430 module.__dict__['__loader__'] = old_dict['__loader__']
482 431 except (TypeError, AttributeError, KeyError):
483 432 pass
484 433
485 434 try:
486 435 module = reload(module)
487 436 except:
488 437 # restore module dictionary on failed reload
489 438 module.__dict__.update(old_dict)
490 439 raise
491 440
492 441 # iterate over all objects and update functions & classes
493 442 for name, new_obj in list(module.__dict__.items()):
494 443 key = (module.__name__, name)
495 444 if key not in old_objects: continue
496 445
497 446 new_refs = []
498 447 for old_ref in old_objects[key]:
499 448 old_obj = old_ref()
500 449 if old_obj is None: continue
501 450 new_refs.append(old_ref)
502 451 update_generic(old_obj, new_obj)
503 452
504 453 if new_refs:
505 454 old_objects[key] = new_refs
506 455 else:
507 456 del old_objects[key]
508 457
509 458 return module
510 459
511 460 #------------------------------------------------------------------------------
512 461 # IPython connectivity
513 462 #------------------------------------------------------------------------------
514 463
515 464 from IPython.core.magic import Magics, magics_class, line_magic
516 465
517 466 @magics_class
518 467 class AutoreloadMagics(Magics):
519 468 def __init__(self, *a, **kw):
520 469 super(AutoreloadMagics, self).__init__(*a, **kw)
521 470 self._reloader = ModuleReloader()
522 471 self._reloader.check_all = False
523 472 self.loaded_modules = set(sys.modules)
524 473
525 474 @line_magic
526 475 def autoreload(self, parameter_s=''):
527 476 r"""%autoreload => Reload modules automatically
528 477
529 478 %autoreload
530 479 Reload all modules (except those excluded by %aimport) automatically
531 480 now.
532 481
533 482 %autoreload 0
534 483 Disable automatic reloading.
535 484
536 485 %autoreload 1
537 486 Reload all modules imported with %aimport every time before executing
538 487 the Python code typed.
539 488
540 489 %autoreload 2
541 490 Reload all modules (except those excluded by %aimport) every time
542 491 before executing the Python code typed.
543 492
544 493 Reloading Python modules in a reliable way is in general
545 494 difficult, and unexpected things may occur. %autoreload tries to
546 495 work around common pitfalls by replacing function code objects and
547 496 parts of classes previously in the module with new versions. This
548 497 makes the following things to work:
549 498
550 499 - Functions and classes imported via 'from xxx import foo' are upgraded
551 500 to new versions when 'xxx' is reloaded.
552 501
553 502 - Methods and properties of classes are upgraded on reload, so that
554 503 calling 'c.foo()' on an object 'c' created before the reload causes
555 504 the new code for 'foo' to be executed.
556 505
557 506 Some of the known remaining caveats are:
558 507
559 508 - Replacing code objects does not always succeed: changing a @property
560 509 in a class to an ordinary method or a method to a member variable
561 510 can cause problems (but in old objects only).
562 511
563 512 - Functions that are removed (eg. via monkey-patching) from a module
564 513 before it is reloaded are not upgraded.
565 514
566 515 - C extension modules cannot be reloaded, and so cannot be
567 516 autoreloaded.
568 517
569 518 """
570 519 if parameter_s == '':
571 520 self._reloader.check(True)
572 521 elif parameter_s == '0':
573 522 self._reloader.enabled = False
574 523 elif parameter_s == '1':
575 524 self._reloader.check_all = False
576 525 self._reloader.enabled = True
577 526 elif parameter_s == '2':
578 527 self._reloader.check_all = True
579 528 self._reloader.enabled = True
580 529
581 530 @line_magic
582 531 def aimport(self, parameter_s='', stream=None):
583 532 """%aimport => Import modules for automatic reloading.
584 533
585 534 %aimport
586 535 List modules to automatically import and not to import.
587 536
588 537 %aimport foo
589 538 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
590 539
591 540 %aimport foo, bar
592 541 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
593 542
594 543 %aimport -foo
595 544 Mark module 'foo' to not be autoreloaded for %autoreload 1
596 545 """
597 546 modname = parameter_s
598 547 if not modname:
599 548 to_reload = sorted(self._reloader.modules.keys())
600 549 to_skip = sorted(self._reloader.skip_modules.keys())
601 550 if stream is None:
602 551 stream = sys.stdout
603 552 if self._reloader.check_all:
604 553 stream.write("Modules to reload:\nall-except-skipped\n")
605 554 else:
606 555 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
607 556 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
608 557 elif modname.startswith('-'):
609 558 modname = modname[1:]
610 559 self._reloader.mark_module_skipped(modname)
611 560 else:
612 561 for _module in ([_.strip() for _ in modname.split(',')]):
613 562 top_module, top_name = self._reloader.aimport_module(_module)
614 563
615 564 # Inject module to user namespace
616 565 self.shell.push({top_name: top_module})
617 566
618 567 def pre_run_cell(self):
619 568 if self._reloader.enabled:
620 569 try:
621 570 self._reloader.check()
622 571 except:
623 572 pass
624 573
625 574 def post_execute_hook(self):
626 575 """Cache the modification times of any modules imported in this execution
627 576 """
628 577 newly_loaded_modules = set(sys.modules) - self.loaded_modules
629 578 for modname in newly_loaded_modules:
630 579 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
631 580 if pymtime is not None:
632 581 self._reloader.modules_mtimes[modname] = pymtime
633 582
634 583 self.loaded_modules.update(newly_loaded_modules)
635 584
636 585
637 586 def load_ipython_extension(ip):
638 587 """Load the extension in IPython."""
639 588 auto_reload = AutoreloadMagics(ip)
640 589 ip.register_magics(auto_reload)
641 590 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
642 591 ip.events.register('post_execute', auto_reload.post_execute_hook)
General Comments 0
You need to be logged in to leave comments. Login now