diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index 294f376..f2823b2 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -30,7 +30,8 @@ from unittest import TestCase from IPython.utils.traitlets import ( HasTraitlets, MetaHasTraitlets, TraitletType, Any, - Int, Long, Float, Complex, Str, Unicode, Bool, TraitletError + Int, Long, Float, Complex, Str, Unicode, Bool, TraitletError, + Undefined, Type, Instance ) @@ -65,15 +66,15 @@ class TestTraitletType(TestCase): self.tt.name = self.name self.hast = HasTraitletsStub() - def test_get(self): + def test_get_undefined(self): value = self.tt.__get__(self.hast) - self.assertEquals(value, None) + self.assertEquals(value, Undefined) def test_set(self): self.tt.__set__(self.hast, 10) self.assertEquals(self.hast._traitlet_values[self.name],10) self.assertEquals(self.hast._notify_name,self.name) - self.assertEquals(self.hast._notify_old,None) + self.assertEquals(self.hast._notify_old,Undefined) self.assertEquals(self.hast._notify_new,10) def test_validate(self): @@ -85,6 +86,20 @@ class TestTraitletType(TestCase): tt.__set__(self.hast, 10) self.assertEquals(tt.__get__(self.hast),-1) + def test_default_validate(self): + class MyIntTT(TraitletType): + def validate(self, obj, value): + if isinstance(value, int): + return value + self.error(obj, value) + tt = MyIntTT(10) + tt.name = 'a' + self.assertEquals(tt.__get__(self.hast), 10) + tt = MyIntTT('bad default') + tt.name = 'b' # different name from 'a' as we want an unset dv + self.assertRaises(TraitletError, tt.__get__, self.hast) + + def test_is_valid_for(self): class MyTT(TraitletType): def is_valid_for(self, value): @@ -92,7 +107,7 @@ class TestTraitletType(TestCase): tt = MyTT() tt.name = self.name tt.__set__(self.hast, 10) - self.assertEquals(tt.__get__(self.hast),10) + self.assertEquals(tt.__get__(self.hast), 10) def test_value_for(self): class MyTT(TraitletType): @@ -101,10 +116,10 @@ class TestTraitletType(TestCase): tt = MyTT() tt.name = self.name tt.__set__(self.hast, 10) - self.assertEquals(tt.__get__(self.hast),20) + self.assertEquals(tt.__get__(self.hast), 20) def test_info(self): - self.assertEquals(self.tt.info(),'any value') + self.assertEquals(self.tt.info(), 'any value') def test_error(self): self.assertRaises(TraitletError, self.tt.error, self.hast, 10) @@ -311,7 +326,113 @@ class TestAddTraitlet(TestCase): # Tests for specific traitlet types #----------------------------------------------------------------------------- + +class TestType(TestCase): + + def test_default(self): + + class B(object): pass + class A(HasTraitlets): + klass = Type + + a = A() + self.assertEquals(a.klass, None) + a.klass = B + self.assertEquals(a.klass, B) + self.assertRaises(TraitletError, setattr, a, 'klass', 10) + + def test_value(self): + + class B(object): pass + class C(object): pass + class A(HasTraitlets): + klass = Type(B) + + a = A() + self.assertEquals(a.klass, B) + self.assertRaises(TraitletError, setattr, a, 'klass', C) + self.assertRaises(TraitletError, setattr, a, 'klass', object) + a.klass = B + + def test_allow_none(self): + + class B(object): pass + class C(B): pass + class A(HasTraitlets): + klass = Type(B, allow_none=False) + + a = A() + self.assertEquals(a.klass, B) + self.assertRaises(TraitletError, setattr, a, 'klass', None) + a.klass = C + self.assertEquals(a.klass, C) + + +class TestInstance(TestCase): + + def test_basic(self): + class Foo(object): pass + class Bar(Foo): pass + class Bah(object): pass + + class A(HasTraitlets): + inst = Instance(Foo) + + a = A() + self.assert_(isinstance(a.inst, Foo)) + a.inst = Foo() + self.assert_(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assert_(isinstance(a.inst, Foo)) + self.assertRaises(TraitletError, setattr, a, 'inst', Foo) + self.assertRaises(TraitletError, setattr, a, 'inst', Bar) + self.assertRaises(TraitletError, setattr, a, 'inst', Bah()) + + def test_unique_default_value(self): + class Foo(object): pass + class A(HasTraitlets): + inst = Instance(Foo) + + a = A() + b = A() + self.assert_(a.inst is not b.inst) + + def test_args_kw(self): + class Foo(object): + def __init__(self, c): self.c = c + + class A(HasTraitlets): + inst = Instance(Foo, args=(10,)) + + a = A() + self.assertEquals(a.inst.c, 10) + + class Bar(object): + def __init__(self, c, d): + self.c = c; self.d = d + + class B(HasTraitlets): + inst = Instance(Bar, args=(10,),kw=dict(d=20)) + b = B() + self.assertEquals(b.inst.c, 10) + self.assertEquals(b.inst.d, 20) + + def test_instance(self): + # Does passing an instance yield a default value of None? + class Foo(object): pass + + class A(HasTraitlets): + inst = Instance(Foo()) + a = A() + self.assertEquals(a.inst, None) + + class B(HasTraitlets): + inst = Instance(Foo(), allow_none=False) + b = B() + self.assertRaises(TraitletError, getattr, b, 'inst') + class TraitletTestBase(TestCase): + """A best testing class for basic traitlet types.""" def assign(self, value): self.obj.value = value diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index ea574d0..807a0cc 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -42,9 +42,13 @@ Authors: # Imports #----------------------------------------------------------------------------- + import inspect +import sys import types -from types import InstanceType +from types import InstanceType, ClassType + +ClassTypes = (ClassType, type) #----------------------------------------------------------------------------- # Basic classes @@ -124,6 +128,12 @@ def parse_notifier_name(name): return name +def get_module_name ( level = 2 ): + """ Returns the name of the module that the caller's caller is located in. + """ + return sys._getframe( level ).f_globals.get( '__name__', '__main__' ) + + #----------------------------------------------------------------------------- # Base TraitletType for all traitlets #----------------------------------------------------------------------------- @@ -132,30 +142,55 @@ def parse_notifier_name(name): class TraitletType(object): metadata = {} - default_value = None + default_value = Undefined info_text = 'any value' def __init__(self, default_value=NoDefaultSpecified, **metadata): + """Create a TraitletType. + """ if default_value is not NoDefaultSpecified: self.default_value = default_value self.metadata.update(metadata) + self.init() - def __get__(self, inst, cls=None): - if inst is None: + def init(self): + pass + + def get_default_value(self): + """Create a new instance of the default value.""" + dv = self.default_value + return dv + + def __get__(self, obj, cls=None): + """Get the value of the traitlet by self.name for the instance. + + The creation of default values is deferred until this is called the + first time. This is done so instances of the parent HasTraitlets + will have their own default value instances. + """ + if obj is None: return self else: - return inst._traitlet_values.get(self.name, self.default_value) - - def __set__(self, inst, value): - new_value = self._validate(inst, value) - old_value = self.__get__(inst) - if old_value != new_value: - inst._traitlet_values[self.name] = new_value - inst._notify(self.name, old_value, new_value) + if not obj._traitlet_values.has_key(self.name): + dv = self.get_default_value() + self.__set__(obj, dv, first=True) + return dv + else: + return obj._traitlet_values[self.name] + + def __set__(self, obj, value, first=False): + new_value = self._validate(obj, value) + if not first: + old_value = self.__get__(obj) + if old_value != new_value: + obj._traitlet_values[self.name] = new_value + obj._notify(self.name, old_value, new_value) + else: + obj._traitlet_values[self.name] = new_value - def _validate(self, inst, value): + def _validate(self, obj, value): if hasattr(self, 'validate'): - return self.validate(inst, value) + return self.validate(obj, value) elif hasattr(self, 'is_valid_for'): valid = self.is_valid_for(value) if valid: @@ -333,10 +368,291 @@ class HasTraitlets(object): inst.name = name setattr(self.__class__, name, inst) + #----------------------------------------------------------------------------- # Actual TraitletTypes implementations/subclasses #----------------------------------------------------------------------------- +#----------------------------------------------------------------------------- +# TraitletTypes subclasses for handling classes and instances of classes +#----------------------------------------------------------------------------- + + +class BaseClassResolver(TraitletType): + """Mixin class for traitlets that need to resolve classes by strings. + + This class provides is a mixin that provides its subclasses with the + ability to resolve classes by specifying a string name (for example, + 'foo.bar.MyClass'). An actual class can also be resolved. + + Any subclass must define instances with 'klass' and 'module' attributes + that contain the string name of the class (or actual class object) and + the module name that contained the original trait definition (used for + resolving local class names (e.g. 'LocalClass')). + """ + + def resolve_class(self, obj, value): + klass = self.validate_class(self.find_class(self.klass)) + if klass is None: + self.validate_failed(obj, value) + + self.klass = klass + + def validate_class(self, klass): + return klass + + def find_class(self, klass): + module = self.module + col = klass.rfind('.') + if col >= 0: + module = klass[ : col ] + klass = klass[ col + 1: ] + + theClass = getattr(sys.modules.get(module), klass, None) + if (theClass is None) and (col >= 0): + try: + mod = __import__(module) + for component in module.split( '.' )[1:]: + mod = getattr(mod, component) + + theClass = getattr(mod, klass, None) + except: + pass + + return theClass + + def validate_failed (self, obj, value): + kind = type(value) + if kind is InstanceType: + msg = 'class %s' % value.__class__.__name__ + else: + msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) ) + + self.error(obj, msg) + + +class Type(BaseClassResolver): + """A traitlet whose value must be a subclass of a specified class.""" + + def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ): + """Construct a Type traitlet + + A Type traitlet specifies that its values must be subclasses of + a particular class. + + Parameters + ---------- + default_value : class or None + The default value must be a subclass of klass. + klass : class, str, None + Values of this traitlet must be a subclass of klass. The klass + may be specified in a string like: 'foo.bar.MyClass'. + allow_none : boolean + Indicates whether None is allowed as an assignable value. Even if + ``False``, the default value may be ``None``. + """ + if default_value is None: + if klass is None: + klass = object + elif klass is None: + klass = default_value + + if isinstance(klass, basestring): + self.validate = self.resolve + elif not isinstance(klass, ClassTypes): + raise TraitletError("A Type traitlet must specify a class.") + + self.klass = klass + self._allow_none = allow_none + self.module = get_module_name() + + super(Type, self).__init__(default_value, **metadata) + + def validate(self, obj, value): + """Validates that the value is a valid object instance.""" + try: + if issubclass(value, self.klass): + return value + except: + if (value is None) and (self._allow_none): + return value + + self.error(obj, value) + + def resolve(self, obj, name, value): + """ Resolves a class originally specified as a string into an actual + class, then resets the trait so that future calls will be handled by + the normal validate method. + """ + if isinstance(self.klass, basestring): + self.resolve_class(obj, value) + del self.validate + + return self.validate(obj, value) + + def info(self): + """ Returns a description of the trait.""" + klass = self.klass + if not isinstance(klass, basestring): + klass = klass.__name__ + + result = 'a subclass of ' + klass + + if self._allow_none: + return result + ' or None' + + return result + + def get_default_value(self): + """ Returns a tuple of the form: ( default_value_type, default_value ) + which describes the default value for this trait. + """ + if not isinstance(self.default_value, basestring): + return super(Type, self).get_default_value() + + dv = self.resolve_default_value() + dvt = type(dv) + return (dvt, dv) + + def resolve_default_value(self): + """ Resolves a class name into a class so that it can be used to + return the class as the default value of the trait. + """ + if isinstance(self.klass, basestring): + try: + self.resolve_class(None, None) + del self.validate + except: + raise TraitletError('Could not resolve %s into a valid class' % + self.klass ) + + return self.klass + + +class DefaultValueGenerator(object): + """A class for generating new default value instances.""" + + def __init__(self, klass, *args, **kw): + self.klass = klass + self.args = args + self.kw = kw + + +class Instance(BaseClassResolver): + """A trait whose value must be an instance of a specified class. + + The value can also be an instance of a subclass of the specified class. + """ + + def __init__(self, klass=None, args=None, kw=None, allow_none=True, + module = None, **metadata ): + """Construct an Instance traitlet. + + Parameters + ---------- + klass : class or instance + The object that forms the basis for the traitlet. If an instance + values must have isinstance(value, type(instance)). + args : tuple + Positional arguments for generating the default value. + kw : dict + Keyword arguments for generating the default value. + allow_none : bool + Indicates whether None is allowed as a value. + + Default Value + ------------- + If klass is an instance, default value is None. If klass is a class + then the default value is obtained by calling ``klass(*args, **kw)``. + If klass is a str, it is first resolved to an actual class and then + instantiated with ``klass(*args, **kw)``. + """ + + self._allow_none = allow_none + self.module = module or get_module_name() + + if klass is None: + raise TraitletError('A %s traitlet must have a class specified.' % + self.__class__.__name__ ) + elif not isinstance(klass, (basestring,) + ClassTypes ): + # klass is an instance so default value will be None + self.klass = klass.__class__ + default_value = None + else: + # klass is a str or class so we handle args, kw + if args is None: + args = () + if kw is None: + if isinstance(args, dict): + kw = args + args = () + else: + kw = {} + if not isinstance(kw, dict): + raise TraitletError("The 'kw' argument must be a dict.") + if not isinstance(args, tuple): + raise TraitletError("The 'args' argument must be a tuple.") + self.klass = klass + # This tells my get_default_value that the default value + # instance needs to be generated when it is called. This + # is usually when TraitletType.__get__ is called for the 1st time. + + default_value = DefaultValueGenerator(klass, *args, **kw) + + super(Instance, self).__init__(default_value, **metadata) + + def validate(self, obj, value): + if value is None: + if self._allow_none: + return value + self.validate_failed(obj, value) + + # This is where self.klass is turned into a real class if it was + # a str initially. This happens the first time TraitletType.__set__ + # is called. This does happen if a default value is generated by + # TraitletType.__get__. + if isinstance(self.klass, basestring): + self.resolve_class(obj, value) + + if isinstance(value, self.klass): + return value + else: + self.validate_failed(obj, value) + + def info ( self ): + klass = self.klass + if not isinstance( klass, basestring ): + klass = klass.__name__ + result = class_of(klass) + if self._allow_none: + return result + ' or None' + + return result + + def get_default_value ( self ): + """Instantiate a default value instance. + + When TraitletType.__get__ is called the first time, this is called + (if no value has been assigned) to get a default value instance. + """ + dv = self.default_value + if isinstance(dv, DefaultValueGenerator): + klass = dv.klass + args = dv.args + kw = dv.kw + if isinstance(klass, basestring): + klass = self.validate_class(self.find_class(klass)) + if klass is None: + raise TraitletError('Unable to locate class: ' + dv.klass) + return klass(*args, **kw) + else: + return dv + + +#----------------------------------------------------------------------------- +# Basic TraitletTypes implementations/subclasses +#----------------------------------------------------------------------------- + class Any(TraitletType): default_value = None