From 6f9046f435bd1e6478406e5035bc35352fe40905 2014-05-09 22:59:35 From: Thomas Kluyver Date: 2014-05-09 22:59:35 Subject: [PATCH] Merge pull request #5353 from Zaharid/nonetraits Allow None for most traits --- diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index 2cbd920..e8da006 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -673,6 +673,23 @@ class TraitTestBase(TestCase): if hasattr(self, '_default_value'): self.assertEqual(self._default_value, self.obj.value) + def test_allow_none(self): + if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and + None in self._bad_values): + trait=self.obj.traits()['value'] + try: + trait.allow_none = True + self._bad_values.remove(None) + #skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value,None) + self.test_good_values() + self.test_bad_values() + finally: + #tear down + trait.allow_none = False + self._bad_values.append(None) + def tearDown(self): # restore default value after tests, if set if hasattr(self, '_default_value'): @@ -830,7 +847,7 @@ class TestObjectName(TraitTestBase): _default_value = "abc" _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"] _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]", - object(), object] + None, object(), object] if sys.version_info[0] < 3: _bad_values.append(u"þ") else: @@ -845,7 +862,7 @@ class TestDottedObjectName(TraitTestBase): _default_value = "a.b" _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"] - _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."] + _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] if sys.version_info[0] < 3: _bad_values.append(u"t.þ") else: @@ -862,7 +879,7 @@ class TestTCPAddress(TraitTestBase): _default_value = ('127.0.0.1',0) _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)] - _bad_values = [(0,0),('localhost',10.0),('localhost',-1)] + _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None] class ListTrait(HasTraits): @@ -919,14 +936,14 @@ class TestLenList(TraitTestBase): class TupleTrait(HasTraits): - value = Tuple(Int) + value = Tuple(Int(allow_none=True)) class TestTupleTrait(TraitTestBase): obj = TupleTrait() _default_value = None - _good_values = [(1,), None, (0,), [1]] + _good_values = [(1,), None, (0,), [1], (None,)] _bad_values = [10, (1,2), ('a'), ()] def coerce(self, value): diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index 0cc0c75..69d7b64 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -252,13 +252,16 @@ class TraitType(object): metadata = {} default_value = Undefined + allow_none = False info_text = 'any value' - def __init__(self, default_value=NoDefaultSpecified, **metadata): + def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata): """Create a TraitType. """ if default_value is not NoDefaultSpecified: self.default_value = default_value + if allow_none is not None: + self.allow_none = allow_none if len(metadata) > 0: if len(self.metadata) > 0: @@ -371,6 +374,8 @@ class TraitType(object): obj._notify_trait(self.name, old_value, new_value) def _validate(self, obj, value): + if value is None and self.allow_none: + return value if hasattr(self, 'validate'): return self.validate(obj, value) elif hasattr(self, 'is_valid_for'): @@ -745,9 +750,8 @@ class Type(ClassBasedTraitType): raise TraitError("A Type trait must specify a class.") self.klass = klass - self._allow_none = allow_none - super(Type, self).__init__(default_value, **metadata) + super(Type, self).__init__(default_value, allow_none=allow_none, **metadata) def validate(self, obj, value): """Validates that the value is a valid object instance.""" @@ -755,8 +759,7 @@ class Type(ClassBasedTraitType): if issubclass(value, self.klass): return value except: - if (value is None) and (self._allow_none): - return value + pass self.error(obj, value) @@ -767,7 +770,7 @@ class Type(ClassBasedTraitType): else: klass = self.klass.__name__ result = 'a subclass of ' + klass - if self._allow_none: + if self.allow_none: return result + ' or None' return result @@ -831,8 +834,6 @@ class Instance(ClassBasedTraitType): not (but not both), None is replace by ``()`` or ``{}``. """ - self._allow_none = allow_none - if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types))): raise TraitError('The klass argument must be a class' ' you gave: %r' % klass) @@ -856,14 +857,9 @@ class Instance(ClassBasedTraitType): default_value = DefaultValueGenerator(*args, **kw) - super(Instance, self).__init__(default_value, **metadata) + super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata) def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - self.error(obj, value) - if isinstance(value, self.klass): return value else: @@ -875,7 +871,7 @@ class Instance(ClassBasedTraitType): else: klass = self.klass.__name__ result = class_of(klass) - if self._allow_none: + if self.allow_none: return result + ' or None' return result @@ -1169,14 +1165,9 @@ class Enum(TraitType): def __init__(self, values, default_value=None, allow_none=True, **metadata): self.values = values - self._allow_none = allow_none - super(Enum, self).__init__(default_value, **metadata) + super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata) def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - if value in self.values: return value self.error(obj, value) @@ -1184,7 +1175,7 @@ class Enum(TraitType): def info(self): """ Returns a description of the trait.""" result = 'any of ' + repr(self.values) - if self._allow_none: + if self.allow_none: return result + ' or None' return result @@ -1192,10 +1183,6 @@ class CaselessStrEnum(Enum): """An enum of strings that are caseless in validate.""" def validate(self, obj, value): - if value is None: - if self._allow_none: - return value - if not isinstance(value, py3compat.string_types): self.error(obj, value) @@ -1290,7 +1277,7 @@ class Container(Instance): return value for v in value: try: - v = self._trait.validate(obj, v) + v = self._trait._validate(obj, v) except TraitError: self.element_error(obj, v, self._trait) else: @@ -1366,8 +1353,6 @@ class List(Container): def validate(self, obj, value): value = super(List, self).validate(obj, value) - if value is None: - return value value = self.validate_elements(obj, value) @@ -1463,7 +1448,7 @@ class Tuple(Container): validated = [] for t,v in zip(self._traits, value): try: - v = t.validate(obj, v) + v = t._validate(obj, v) except TraitError: self.element_error(obj, v, t) else: