From ddeb9bb31452dda01ff1abd1e74664a29e81a3bf 2012-10-31 22:53:11
From: Bradley M. Froehle <brad.froehle@gmail.com>
Date: 2012-10-31 22:53:11
Subject: [PATCH] Merge pull request #2462 from takluyver/extensions-loaded

Track which extensions are loaded
---

diff --git a/IPython/core/extensions.py b/IPython/core/extensions.py
index 594a344..6ee6bcb 100644
--- a/IPython/core/extensions.py
+++ b/IPython/core/extensions.py
@@ -23,8 +23,12 @@ import sys
 from urllib import urlretrieve
 from urlparse import urlparse
 
+from IPython.core.error import UsageError
 from IPython.config.configurable import Configurable
 from IPython.utils.traitlets import Instance
+from IPython.utils.py3compat import PY3
+if PY3:
+    from imp import reload
 
 #-----------------------------------------------------------------------------
 # Main class
@@ -44,10 +48,11 @@ class ExtensionManager(Configurable):
     the only argument.  You can do anything you want with IPython at
     that point, including defining new magic and aliases, adding new
     components, etc.
-
-    The :func:`load_ipython_extension` will be called again is you
-    load or reload the extension again.  It is up to the extension
-    author to add code to manage that.
+    
+    You can also optionally define an :func:`unload_ipython_extension(ipython)`
+    function, which will be called if the user unloads or reloads the extension.
+    The extension manager will only call :func:`load_ipython_extension` again
+    if the extension is reloaded.
 
     You can put your extension modules anywhere you want, as long as
     they can be imported by Python's standard import mechanism.  However,
@@ -63,6 +68,7 @@ class ExtensionManager(Configurable):
         self.shell.on_trait_change(
             self._on_ipython_dir_changed, 'ipython_dir'
         )
+        self.loaded = set()
 
     def __del__(self):
         self.shell.on_trait_change(
@@ -80,26 +86,43 @@ class ExtensionManager(Configurable):
     def load_extension(self, module_str):
         """Load an IPython extension by its module name.
 
-        If :func:`load_ipython_extension` returns anything, this function
-        will return that object.
+        Returns the string "already loaded" if the extension is already loaded,
+        "no load function" if the module doesn't have a load_ipython_extension
+        function, or None if it succeeded.
         """
+        if module_str in self.loaded:
+            return "already loaded"
+        
         from IPython.utils.syspathcontext import prepended_to_syspath
 
         if module_str not in sys.modules:
             with prepended_to_syspath(self.ipython_extension_dir):
                 __import__(module_str)
         mod = sys.modules[module_str]
-        return self._call_load_ipython_extension(mod)
+        if self._call_load_ipython_extension(mod):
+            self.loaded.add(module_str)
+        else:
+            return "no load function"
 
     def unload_extension(self, module_str):
         """Unload an IPython extension by its module name.
 
         This function looks up the extension's name in ``sys.modules`` and
         simply calls ``mod.unload_ipython_extension(self)``.
+        
+        Returns the string "no unload function" if the extension doesn't define
+        a function to unload itself, "not loaded" if the extension isn't loaded,
+        otherwise None.
         """
+        if module_str not in self.loaded:
+            return "not loaded"
+        
         if module_str in sys.modules:
             mod = sys.modules[module_str]
-            self._call_unload_ipython_extension(mod)
+            if self._call_unload_ipython_extension(mod):
+                self.loaded.discard(module_str)
+            else:
+                return "no unload function"
 
     def reload_extension(self, module_str):
         """Reload an IPython extension by calling reload.
@@ -111,21 +134,25 @@ class ExtensionManager(Configurable):
         """
         from IPython.utils.syspathcontext import prepended_to_syspath
 
-        with prepended_to_syspath(self.ipython_extension_dir):
-            if module_str in sys.modules:
-                mod = sys.modules[module_str]
+        if (module_str in self.loaded) and (module_str in sys.modules):
+            self.unload_extension(module_str)
+            mod = sys.modules[module_str]
+            with prepended_to_syspath(self.ipython_extension_dir):
                 reload(mod)
-                self._call_load_ipython_extension(mod)
-            else:
-                self.load_extension(module_str)
+            if self._call_load_ipython_extension(mod):
+                self.loaded.add(module_str)
+        else:
+            self.load_extension(module_str)
 
     def _call_load_ipython_extension(self, mod):
         if hasattr(mod, 'load_ipython_extension'):
-            return mod.load_ipython_extension(self.shell)
+            mod.load_ipython_extension(self.shell)
+            return True
 
     def _call_unload_ipython_extension(self, mod):
         if hasattr(mod, 'unload_ipython_extension'):
-            return mod.unload_ipython_extension(self.shell)
+            mod.unload_ipython_extension(self.shell)
+            return True
     
     def install_extension(self, url, filename=None):
         """Download and install an IPython extension. 
diff --git a/IPython/core/magics/extension.py b/IPython/core/magics/extension.py
index 93a8bc1..31356b0 100644
--- a/IPython/core/magics/extension.py
+++ b/IPython/core/magics/extension.py
@@ -59,14 +59,30 @@ class ExtensionMagics(Magics):
         """Load an IPython extension by its module name."""
         if not module_str:
             raise UsageError('Missing module name.')
-        return self.shell.extension_manager.load_extension(module_str)
+        res = self.shell.extension_manager.load_extension(module_str)
+        
+        if res == 'already loaded':
+            print "The %s extension is already loaded. To reload it, use:" % module_str
+            print "  %reload_ext", module_str
+        elif res == 'no load function':
+            print "The %s module is not an IPython extension." % module_str
 
     @line_magic
     def unload_ext(self, module_str):
-        """Unload an IPython extension by its module name."""
+        """Unload an IPython extension by its module name.
+        
+        Not all extensions can be unloaded, only those which define an
+        ``unload_ipython_extension`` function.
+        """
         if not module_str:
             raise UsageError('Missing module name.')
-        self.shell.extension_manager.unload_extension(module_str)
+        
+        res = self.shell.extension_manager.unload_extension(module_str)
+        
+        if res == 'no unload function':
+            print "The %s extension doesn't define how to unload it." % module_str
+        elif res == "not loaded":
+            print "The %s extension is not loaded." % module_str
 
     @line_magic
     def reload_ext(self, module_str):
diff --git a/IPython/core/tests/test_extension.py b/IPython/core/tests/test_extension.py
new file mode 100644
index 0000000..43446ba
--- /dev/null
+++ b/IPython/core/tests/test_extension.py
@@ -0,0 +1,73 @@
+import os.path
+
+import nose.tools as nt
+
+import IPython.testing.tools as tt
+from IPython.utils.syspathcontext import prepended_to_syspath
+from IPython.utils.tempdir import TemporaryDirectory
+
+ext1_content = """
+def load_ipython_extension(ip):
+    print("Running ext1 load")
+
+def unload_ipython_extension(ip):
+    print("Running ext1 unload")
+"""
+
+ext2_content = """
+def load_ipython_extension(ip):
+    print("Running ext2 load")
+"""
+
+def test_extension_loading():
+    em = get_ipython().extension_manager
+    with TemporaryDirectory() as td:
+        ext1 = os.path.join(td, 'ext1.py')
+        with open(ext1, 'w') as f:
+            f.write(ext1_content)
+        
+        ext2 = os.path.join(td, 'ext2.py')
+        with open(ext2, 'w') as f:
+            f.write(ext2_content)
+        
+        with prepended_to_syspath(td):
+            assert 'ext1' not in em.loaded
+            assert 'ext2' not in em.loaded
+            
+            # Load extension
+            with tt.AssertPrints("Running ext1 load"):
+                assert em.load_extension('ext1') is None
+            assert 'ext1' in em.loaded
+            
+            # Should refuse to load it again
+            with tt.AssertNotPrints("Running ext1 load"):
+                assert em.load_extension('ext1') == 'already loaded'
+            
+            # Reload
+            with tt.AssertPrints("Running ext1 unload"):
+                with tt.AssertPrints("Running ext1 load", suppress=False):
+                    em.reload_extension('ext1')
+            
+            # Unload
+            with tt.AssertPrints("Running ext1 unload"):
+                assert em.unload_extension('ext1') is None
+            
+            # Can't unload again
+            with tt.AssertNotPrints("Running ext1 unload"):
+                assert em.unload_extension('ext1') == 'not loaded'
+            assert em.unload_extension('ext2') == 'not loaded'
+            
+            # Load extension 2
+            with tt.AssertPrints("Running ext2 load"):
+                assert em.load_extension('ext2') is None
+            
+            # Can't unload this
+            assert em.unload_extension('ext2') == 'no unload function'
+            
+            # But can reload it
+            with tt.AssertPrints("Running ext2 load"):
+                em.reload_extension('ext2')
+
+def test_non_extension():
+    em = get_ipython().extension_manager
+    nt.assert_equal(em.load_extension('sys'), "no load function")
diff --git a/IPython/extensions/autoreload.py b/IPython/extensions/autoreload.py
index 27334ed..40ae40d 100644
--- a/IPython/extensions/autoreload.py
+++ b/IPython/extensions/autoreload.py
@@ -514,14 +514,8 @@ class AutoreloadMagics(Magics):
             pass
 
 
-_loaded = False
-
-
 def load_ipython_extension(ip):
     """Load the extension in IPython."""
-    global _loaded
-    if not _loaded:
-        auto_reload = AutoreloadMagics(ip)
-        ip.register_magics(auto_reload)
-        ip.set_hook('pre_run_code_hook', auto_reload.pre_run_code_hook)
-        _loaded = True
+    auto_reload = AutoreloadMagics(ip)
+    ip.register_magics(auto_reload)
+    ip.set_hook('pre_run_code_hook', auto_reload.pre_run_code_hook)
diff --git a/IPython/extensions/cythonmagic.py b/IPython/extensions/cythonmagic.py
index ae640fe..8edf5d5 100644
--- a/IPython/extensions/cythonmagic.py
+++ b/IPython/extensions/cythonmagic.py
@@ -273,11 +273,7 @@ class CythonMagics(Magics):
         html = '\n'.join(l for l in html.splitlines() if not r.match(l))
         return html
 
-_loaded = False
 
 def load_ipython_extension(ip):
     """Load the extension in IPython."""
-    global _loaded
-    if not _loaded:
-        ip.register_magics(CythonMagics)
-        _loaded = True
+    ip.register_magics(CythonMagics)
diff --git a/IPython/extensions/octavemagic.py b/IPython/extensions/octavemagic.py
index edba5c5..d8798cb 100644
--- a/IPython/extensions/octavemagic.py
+++ b/IPython/extensions/octavemagic.py
@@ -362,10 +362,6 @@ __doc__ = __doc__.format(
     )
 
 
-_loaded = False
 def load_ipython_extension(ip):
     """Load the extension in IPython."""
-    global _loaded
-    if not _loaded:
-        ip.register_magics(OctaveMagics)
-        _loaded = True
+    ip.register_magics(OctaveMagics)
diff --git a/IPython/extensions/rmagic.py b/IPython/extensions/rmagic.py
index f9cf9f7..1fdfa95 100644
--- a/IPython/extensions/rmagic.py
+++ b/IPython/extensions/rmagic.py
@@ -588,10 +588,6 @@ __doc__ = __doc__.format(
 )
 
 
-_loaded = False
 def load_ipython_extension(ip):
     """Load the extension in IPython."""
-    global _loaded
-    if not _loaded:
-        ip.register_magics(RMagics)
-        _loaded = True
+    ip.register_magics(RMagics)
diff --git a/IPython/extensions/storemagic.py b/IPython/extensions/storemagic.py
index b02457b..ca9376b 100644
--- a/IPython/extensions/storemagic.py
+++ b/IPython/extensions/storemagic.py
@@ -209,12 +209,6 @@ class StoreMagics(Magics):
                 print "Stored '%s' (%s)" % (args[0], obj.__class__.__name__)
 
 
-_loaded = False
-
-
 def load_ipython_extension(ip):
     """Load the extension in IPython."""
-    global _loaded
-    if not _loaded:
-        ip.register_magics(StoreMagics)
-        _loaded = True
+    ip.register_magics(StoreMagics)