diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py
index 8344a3d..4bb70cb 100644
--- a/IPython/utils/tests/test_traitlets.py
+++ b/IPython/utils/tests/test_traitlets.py
@@ -1079,6 +1079,28 @@ def test_dict_assignment():
     nt.assert_equal(d, c.value)
     nt.assert_true(c.value is d)
 
+
+class ValidatedDictTrait(HasTraits):
+
+    value = Dict(Unicode())
+
+
+class TestValidatedDict(TraitTestBase):
+
+    obj = ValidatedDictTrait()
+
+    _default_value = {}
+    _good_values = [{'0': 'foo'}, {'1': 'bar'}]
+    _bad_values = [{'0': 0}, {'1': 1}]
+
+
+def test_dict_positional_default_value():
+    """A non-trait first positional argument is still used as the default."""
+    class C(HasTraits):
+        value = Dict({'a': 1})
+
+    nt.assert_equal(C().value, {'a': 1})
+
 
 def test_dict_default_value():
     """Check that the `{}` default value of the Dict traitlet constructor is actually copied."""
diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py
index e124971..297cd07 100644
--- a/IPython/utils/traitlets.py
+++ b/IPython/utils/traitlets.py
@@ -1645,13 +1645,42 @@ class Tuple(Container):
 
 class Dict(Instance):
     """An instance of a Python dict."""
+    # Element trait used to validate the dict's values; None means unchecked.
+    _trait = None
 
-    def __init__(self, default_value={}, allow_none=False, **metadata):
+    def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
         """Create a dict trait type from a dict.
 
         The default value is created by doing ``dict(default_value)``,
         which creates a copy of the ``default_value``.
+
+        trait : TraitType [ optional ]
+            The trait used to validate each *value* of the dict (keys are
+            not checked). May be a TraitType subclass or instance. If
+            unspecified, values are not checked.
+
+        default_value : dict or None [ optional ]
+            The default value for the Dict; ``{}`` if not given. A non-None
+            value is copied with ``dict(default_value)``. If `trait` is
+            specified, the `default_value` must conform to the constraints
+            it specifies.
+
+        allow_none : bool [ default False ]
+            Whether to allow the value to be None
+
+        For backward compatibility with ``Dict(default_value)``, a first
+        positional argument that is not a trait is used as `default_value`.
+        A positional ``None`` counts as `trait`, so ``Dict(None)`` gets a
+        ``{}`` default; pass ``default_value=None`` for a None default.
+
         """
+        if default_value is NoDefaultSpecified and trait is not None:
+            if not is_trait(trait):
+                # Old-style Dict(default_value) call: shift the argument over.
+                default_value = trait
+                trait = None
+        if default_value is NoDefaultSpecified:
+            default_value = {}
         if default_value is None:
             args = None
         elif isinstance(default_value, dict):
@@ -1661,9 +1690,57 @@ class Dict(Instance):
         else:
             raise TypeError('default value of Dict was %s' % default_value)
 
+        if is_trait(trait):
+            # Accept either a TraitType class or an instance.
+            self._trait = trait() if isinstance(trait, type) else trait
+            self._trait.name = 'element'
+        elif trait is not None:
+            raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
+
         super(Dict,self).__init__(klass=dict, args=args,
                                   allow_none=allow_none, **metadata)
 
+    def element_error(self, obj, element, validator):
+        """Raise a TraitError for a dict value that `validator` rejected."""
+        e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
+            % (self.name, class_of(obj), validator.info(), repr_type(element))
+        raise TraitError(e)
+
+    def validate(self, obj, value):
+        """Run the base Instance validation, then validate the dict's values."""
+        value = super(Dict, self).validate(obj, value)
+        if value is None:
+            return value
+        value = self.validate_elements(obj, value)
+        return value
+
+    def validate_elements(self, obj, value):
+        """Return a new dict whose values have passed the element trait.
+
+        Keys are copied unchanged. If there is no element trait (or it is
+        `Any`), `value` is returned as-is without copying.
+        """
+        if self._trait is None or isinstance(self._trait, Any):
+            return value
+        validated = {}
+        for key, v in value.items():
+            try:
+                v = self._trait._validate(obj, v)
+            except TraitError:
+                self.element_error(obj, v, self._trait)
+            else:
+                validated[key] = v
+        return self.klass(validated)
+
+    def instance_init(self, obj):
+        # Give the element trait the same owner-class context as this trait
+        # before running the usual Instance initialisation.
+        if isinstance(self._trait, TraitType):
+            self._trait.this_class = self.this_class
+            if hasattr(self._trait, '_resolve_classes'):
+                self._trait._resolve_classes(obj)
+        super(Dict, self).instance_init(obj)
+
 
 class EventfulDict(Instance):
     """An instance of an EventfulDict."""