diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index cc1360b..1a63c67 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -16,7 +16,7 @@ import nose.tools as nt from nose import SkipTest from IPython.utils.traitlets import ( - HasTraits, MetaHasTraits, TraitType, AllowNone, Any, CBytes, Dict, + HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict, Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError, Undefined, Type, This, Instance, TCPAddress, List, Tuple, ObjectName, DottedObjectName, CRegExp, link @@ -73,7 +73,7 @@ class TestTraitType(TestCase): self.assertEqual(a.tt, -1) def test_default_validate(self): - class MyIntTT(AllowNone): + class MyIntTT(TraitType): def validate(self, obj, value): if isinstance(value, int): return value @@ -354,29 +354,29 @@ class TestHasTraitsNotify(TestCase): class A(HasTraits): listen_to = ['a'] - + a = Int(0) b = 0 - + def __init__(self, **kwargs): super(A, self).__init__(**kwargs) self.on_trait_change(self.listener1, ['a']) - + def listener1(self, name, old, new): self.b += 1 class B(A): - + c = 0 d = 0 - + def __init__(self, **kwargs): super(B, self).__init__(**kwargs) self.on_trait_change(self.listener2) - + def listener2(self, name, old, new): self.c += 1 - + def _a_changed(self, name, old, new): self.d += 1 @@ -442,7 +442,7 @@ class TestHasTraits(TestCase): def __init__(self, i): super(A, self).__init__() self.i = i - + a = A(5) self.assertEqual(a.i, 5) # should raise TypeError if no positional arg given @@ -677,19 +677,18 @@ class TraitTestBase(TestCase): if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and None in self._bad_values): trait=self.obj.traits()['value'] - if isinstance(trait, AllowNone) and not trait._allow_none: - try: - trait._allow_none = True - self._bad_values.remove(None) - #skip coerce. Allow None casts None to None. 
- self.assign(None) - self.assertEqual(self.obj.value,None) - self.test_good_values() - self.test_bad_values() - finally: - #tear down - trait._allow_none = False - self._bad_values.append(None) + try: + trait.allow_none = True + self._bad_values.remove(None) + #skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value,None) + self.test_good_values() + self.test_bad_values() + finally: + #tear down + trait.allow_none = False + self._bad_values.append(None) def tearDown(self): @@ -894,7 +893,7 @@ class TestList(TraitTestBase): _default_value = [] _good_values = [[], [1], list(range(10)), (1,2)] _bad_values = [10, [1,'a'], 'a'] - + def coerce(self, value): if value is not None: value = list(value) @@ -1073,7 +1072,7 @@ class TestLink(TestCase): count = Int() a = A(value=9) b = B(count=8) - + # Register callbacks that count. callback_count = [] def a_callback(name, old, new): @@ -1124,4 +1123,4 @@ def test_pickle_hastraits(): c2 = pickle.loads(p) nt.assert_equal(c2.i, c.i) nt.assert_equal(c2.j, c.j) - + \ No newline at end of file diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index d1bb095..36711d9 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -220,7 +220,7 @@ class link(object): for obj,attr in self.objects.keys(): if obj is not sending_obj or attr != sending_attr: setattr(obj, attr, new) - + def unlink(self): for key, callback in self.objects.items(): (obj,attr) = key @@ -252,13 +252,16 @@ class TraitType(object): metadata = {} default_value = Undefined + allow_none = False info_text = 'any value' - def __init__(self, default_value=NoDefaultSpecified, **metadata): + def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata): """Create a TraitType. 
""" if default_value is not NoDefaultSpecified: self.default_value = default_value + if allow_none is not None: + self.allow_none = allow_none if len(metadata) > 0: if len(self.metadata) > 0: @@ -371,6 +374,8 @@ class TraitType(object): obj._notify_trait(self.name, old_value, new_value) def _validate(self, obj, value): + if value is None and self.allow_none: + return value if hasattr(self, 'validate'): return self.validate(obj, value) elif hasattr(self, 'is_valid_for'): @@ -677,29 +682,16 @@ class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)): else: return trait.get_metadata(key) -class AllowNone(TraitType): - """A trait that can be set to allow None values. It does not provide - validation.""" - def __init__(self, default_value=NoDefaultSpecified, allow_none = False, **metadata): - self._allow_none = allow_none - super(AllowNone, self).__init__(default_value, **metadata) - - def _none_ok(self, value): - """The validate method can return the None value.""" - return value is None and self._allow_none - - #----------------------------------------------------------------------------- # Actual TraitTypes implementations/subclasses #----------------------------------------------------------------------------- - #----------------------------------------------------------------------------- # TraitTypes subclasses for handling classes and instances of classes #----------------------------------------------------------------------------- -class ClassBasedTraitType(AllowNone): +class ClassBasedTraitType(TraitType): """A trait with error reporting for Type, Instance and This.""" def error(self, obj, value): @@ -767,8 +759,7 @@ class Type(ClassBasedTraitType): if issubclass(value, self.klass): return value except: - if self._none_ok(value): - return value + pass self.error(obj, value) @@ -779,7 +770,7 @@ class Type(ClassBasedTraitType): else: klass = self.klass.__name__ result = 'a subclass of ' + klass - if self._allow_none: + if self.allow_none: return result + 
' or None' return result @@ -869,11 +860,6 @@ class Instance(ClassBasedTraitType): super(Instance, self).__init__(default_value, allow_none, **metadata) def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - self.error(obj, value) - if isinstance(value, self.klass): return value else: @@ -885,7 +871,7 @@ class Instance(ClassBasedTraitType): else: klass = self.klass.__name__ result = class_of(klass) - if self._allow_none: + if self.allow_none: return result + ' or None' return result @@ -945,14 +931,14 @@ class Any(TraitType): info_text = 'any value' -class Int(AllowNone): +class Int(TraitType): """An int trait.""" default_value = 0 info_text = 'an int' def validate(self, obj, value): - if isinstance(value, int) or self._none_ok(value): + if isinstance(value, int): return value self.error(obj, value) @@ -963,22 +949,20 @@ class CInt(Int): try: return int(value) except: - if self._none_ok(value): - return value self.error(obj, value) if py3compat.PY3: Long, CLong = Int, CInt Integer = Int else: - class Long(AllowNone): + class Long(TraitType): """A long integer trait.""" default_value = 0 info_text = 'a long' def validate(self, obj, value): - if isinstance(value, long) or self._none_ok(value): + if isinstance(value, long): return value if isinstance(value, int): return long(value) @@ -992,11 +976,9 @@ else: try: return long(value) except: - if self._none_ok(value): - return value self.error(obj, value) - class Integer(AllowNone): + class Integer(TraitType): """An integer trait. 
Longs that are unnecessary (<= sys.maxint) are cast to ints.""" @@ -1005,7 +987,7 @@ else: info_text = 'an integer' def validate(self, obj, value): - if isinstance(value, int) or self._none_ok(value): + if isinstance(value, int): return value if isinstance(value, long): # downcast longs that fit in int: @@ -1019,14 +1001,14 @@ else: self.error(obj, value) -class Float(AllowNone): +class Float(TraitType): """A float trait.""" default_value = 0.0 info_text = 'a float' def validate(self, obj, value): - if isinstance(value, float ) or self._none_ok(value): + if isinstance(value, float): return value if isinstance(value, int): return float(value) @@ -1040,18 +1022,16 @@ class CFloat(Float): try: return float(value) except: - if self._none_ok(value): - return value self.error(obj, value) -class Complex(AllowNone): +class Complex(TraitType): """A trait for complex numbers.""" default_value = 0.0 + 0.0j info_text = 'a complex number' def validate(self, obj, value): - if isinstance(value, complex) or self._none_ok(value): + if isinstance(value, complex): return value if isinstance(value, (float, int)): return complex(value) @@ -1065,21 +1045,19 @@ class CComplex(Complex): try: return complex(value) except: - if self._noe_ok(value): - return value self.error(obj, value) # We should always be explicit about whether we're using bytes or unicode, both # for Python 3 conversion and for reliable unicode behaviour on Python 2. So # we don't have a Str type. 
-class Bytes(AllowNone): +class Bytes(TraitType): """A trait for byte strings.""" default_value = b'' info_text = 'a bytes object' def validate(self, obj, value): - if isinstance(value, bytes) or self._none_ok(value): + if isinstance(value, bytes): return value self.error(obj, value) @@ -1091,19 +1069,17 @@ class CBytes(Bytes): try: return bytes(value) except: - if self._none_ok(value): - return value self.error(obj, value) -class Unicode(AllowNone): +class Unicode(TraitType): """A trait for unicode strings.""" default_value = u'' info_text = 'a unicode string' def validate(self, obj, value): - if isinstance(value, py3compat.unicode_type) or self._none_ok(value): + if isinstance(value, py3compat.unicode_type): return value if isinstance(value, bytes): try: @@ -1121,12 +1097,10 @@ class CUnicode(Unicode): try: return py3compat.unicode_type(value) except: - if self._allow_none(value): - return value self.error(obj, value) -class ObjectName(AllowNone): +class ObjectName(TraitType): """A string holding a valid object name in this version of Python. 
This does not check that the name exists in any scope.""" @@ -1148,8 +1122,6 @@ class ObjectName(AllowNone): return value def validate(self, obj, value): - if self._none_ok(value): - return value value = self.coerce_str(obj, value) if isinstance(value, str) and py3compat.isidentifier(value): @@ -1159,8 +1131,6 @@ class ObjectName(AllowNone): class DottedObjectName(ObjectName): """A string holding a valid dotted object name in Python, such as A.b3._c""" def validate(self, obj, value): - if self._none_ok(value): - return value value = self.coerce_str(obj, value) if isinstance(value, str) and py3compat.isidentifier(value, dotted=True): @@ -1168,14 +1138,14 @@ class DottedObjectName(ObjectName): self.error(obj, value) -class Bool(AllowNone): +class Bool(TraitType): """A boolean (True, False) trait.""" default_value = False info_text = 'a boolean' def validate(self, obj, value): - if isinstance(value, bool) or self._none_ok(value): + if isinstance(value, bool): return value self.error(obj, value) @@ -1195,14 +1165,9 @@ class Enum(TraitType): def __init__(self, values, default_value=None, allow_none=True, **metadata): self.values = values - self._allow_none = allow_none - super(Enum, self).__init__(default_value, **metadata) + super(Enum, self).__init__(default_value, allow_none, **metadata) def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - if value in self.values: return value self.error(obj, value) @@ -1210,7 +1175,7 @@ class Enum(TraitType): def info(self): """ Returns a description of the trait.""" result = 'any of ' + repr(self.values) - if self._allow_none: + if self.allow_none: return result + ' or None' return result @@ -1218,10 +1183,6 @@ class CaselessStrEnum(Enum): """An enum of strings that are caseless in validate.""" def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - if not isinstance(value, py3compat.string_types): self.error(obj, value) @@ -1384,7 +1345,7 @@ class 
List(Container): self.length_error(obj, value) return super(List, self).validate_elements(obj, value) - + def validate(self, obj, value): value = super(List, self).validate(obj, value) if value is None: @@ -1393,7 +1354,7 @@ class List(Container): value = self.validate_elements(obj, value) return value - + class Set(List): @@ -1513,7 +1474,7 @@ class Dict(Instance): super(Dict,self).__init__(klass=dict, args=args, allow_none=allow_none, **metadata) -class TCPAddress(AllowNone): +class TCPAddress(TraitType): """A trait for an (ip, port) tuple. This allows for both IPv4 IP addresses as well as hostnames. @@ -1529,11 +1490,9 @@ class TCPAddress(AllowNone): port = value[1] if port >= 0 and port <= 65535: return value - if self._none_ok(value): - return value self.error(obj, value) -class CRegExp(AllowNone): +class CRegExp(TraitType): """A casting compiled regular expression trait. Accepts both strings and compiled regular expressions. The resulting @@ -1545,6 +1504,4 @@ class CRegExp(AllowNone): try: return re.compile(value) except: - if self._none_ok(value): - return value self.error(obj, value)