# ---------------------------------------------------------------------------
# IPython/config/configurable.py
# ---------------------------------------------------------------------------

class MultipleInstanceError(ConfigurableError):
    """Raised when incompatible singleton instances are requested.

    :meth:`SingletonConfigurable.instance` raises this when the instance
    already registered for a class is not an instance of that class.
    """
    pass


class SingletonConfigurable(Configurable):
    """A configurable that only allows one instance.

    This class is for classes that should only have one instance of itself
    or *any* subclass. To create and retrieve such a class use the
    :meth:`SingletonConfigurable.instance` method.
    """

    # The shared instance.  It is stored on every singleton class in the
    # MRO of the class that created it, but never on SingletonConfigurable
    # itself, so unrelated singleton hierarchies stay independent.
    _instance = None

    @classmethod
    def instance(cls, *args, **kwargs):
        """Returns a global instance of this class.

        This method creates a new instance if none have previously been
        created and returns a previously created instance if one already
        exists.

        The arguments and keyword arguments passed to this method are passed
        on to the :meth:`__init__` method of the class upon instantiation.
        They are ignored when an instance already exists.

        Raises
        ------
        MultipleInstanceError
            If the instance already registered for this class is not an
            instance of this class (for example, it was created through a
            parent or sibling class).

        Examples
        --------

        Create a singleton class using instance, and retrieve it::

            >>> from IPython.config.configurable import SingletonConfigurable
            >>> class Foo(SingletonConfigurable): pass
            >>> foo = Foo.instance()
            >>> foo == Foo.instance()
            True

        Create a subclass that is retrieved using the base class instance::

            >>> class Bar(SingletonConfigurable): pass
            >>> class Bam(Bar): pass
            >>> bam = Bam.instance()
            >>> bam == Bar.instance()
            True
        """
        # Create and save the instance
        if cls._instance is None:
            inst = cls(*args, **kwargs)
            # Now make sure that the instance will also be returned by
            # the parent classes' instance() method.  Every entry of the
            # MRO is inspected (no early exit): with multiple inheritance
            # a plain mixin may sit *between* cls and its singleton
            # parents, and those parents must still be registered.
            for klass in cls.mro():
                if issubclass(klass, SingletonConfigurable) and \
                        klass is not SingletonConfigurable:
                    klass._instance = inst

        if isinstance(cls._instance, cls):
            return cls._instance

        raise MultipleInstanceError(
            'Multiple incompatible subclass instances of '
            '%s are being created.' % cls.__name__
        )


# ---------------------------------------------------------------------------
# IPython/config/tests/test_configurable.py
# ---------------------------------------------------------------------------

class TestSingletonConfigurable(TestCase):
    """Tests for :meth:`SingletonConfigurable.instance`."""

    def test_instance(self):
        """instance() returns the same object on repeated calls."""
        class Foo(SingletonConfigurable): pass
        foo = Foo.instance()
        self.assertEquals(foo, Foo.instance())
        # The base class itself must never hold an instance.
        self.assertEquals(SingletonConfigurable._instance, None)

    def test_inheritance(self):
        """An instance made via a subclass is shared with its parent."""
        class Bar(SingletonConfigurable): pass
        class Bam(Bar): pass
        bam = Bam.instance()
        self.assertEquals(bam, Bar.instance())
        self.assertEquals(bam, Bam._instance)
        self.assertEquals(bam, Bar._instance)
        self.assertEquals(SingletonConfigurable._instance, None)

    def test_mixin_inheritance(self):
        """A non-singleton mixin in the MRO must not hide singleton parents."""
        class Mixin(object): pass
        class Bar(SingletonConfigurable): pass
        class Bam(Mixin, Bar): pass
        bam = Bam.instance()
        self.assertEquals(bam, Bar._instance)
        self.assertEquals(bam, Bar.instance())
        self.assertEquals(SingletonConfigurable._instance, None)
b/IPython/config/tests/test_loader.py index 3ec7b3c..8a0dcf3 100755 --- a/IPython/config/tests/test_loader.py +++ b/IPython/config/tests/test_loader.py @@ -117,7 +117,6 @@ class TestKeyValueCL(TestCase): def test_basic(self): cl = KeyValueConfigLoader() argv = [s.strip('c.') for s in pyfile.split('\n')[2:-1]] - print argv config = cl.load_config(argv) self.assertEquals(config.a, 10) self.assertEquals(config.b, 20)