diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 6333d2a..1a32fff 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -1366,7 +1366,7 @@ def test_hold_trait_notifications(): nt.assert_equal(t.a, 4) nt.assert_equal(changes, []) - nt.assert_equal(changes, [(3,4)]) + nt.assert_equal(changes, [(0, 4)]) # Test roll-back try: with t.hold_trait_notifications(): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 11dba4c..1869129 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -597,50 +597,49 @@ class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)): yield return else: - self._cross_validation_lock = True cache = {} - notifications = {} _notify_trait = self._notify_trait - def cache_values(*a): - cache[a[0]] = a - - def hold_notifications(*a): - notifications[a[0]] = a + def merge(previous, current): + """merges notifications of the form (name, old, value)""" + if previous is None: + return current + else: + return (current[0], previous[1], current[2]) - self._notify_trait = cache_values + def hold(*a): + cache[a[0]] = merge(cache.get(a[0]), a) try: + self._notify_trait = hold + self._cross_validation_lock = True yield + for name in cache: + if hasattr(self, '_%s_validate' % name): + cross_validate = getattr(self, '_%s_validate' % name) + setattr(self, name, cross_validate(getattr(self, name), self)) + except TraitError as e: + self._notify_trait = lambda *x: None + for name in cache: + if cache[name][1] is not Undefined: + setattr(self, name, cache[name][1]) + else: + delattr(self, name) + cache = {} + raise e finally: - try: - self._notify_trait = hold_notifications - for name in cache: - if hasattr(self, '_%s_validate' % name): - cross_validate = getattr(self, '_%s_validate' % name) - setattr(self, name, cross_validate(getattr(self, name), self)) - except TraitError as e: - self._notify_trait = lambda *x: None - for name in cache: - if cache[name][1] is not Undefined: - setattr(self, name, cache[name][1]) - else: - delattr(self, name) - cache = {} - notifications = {} - raise e - finally: - self._notify_trait = _notify_trait - self._cross_validation_lock = False - if isinstance(_notify_trait, types.MethodType): - # FIXME: remove when support is bumped to 3.4. - # when original method is restored, - # remove the redundant value from __dict__ - # (only used to preserve pickleability on Python < 3.4) - self.__dict__.pop('_notify_trait', None) - # trigger delayed notifications - for v in dict(cache, **notifications).values(): - self._notify_trait(*v) + self._notify_trait = _notify_trait + self._cross_validation_lock = False + if isinstance(_notify_trait, types.MethodType): + # FIXME: remove when support is bumped to 3.4. + # when original method is restored, + # remove the redundant value from __dict__ + # (only used to preserve pickleability on Python < 3.4) + self.__dict__.pop('_notify_trait', None) + + # trigger delayed notifications + for v in cache.values(): + self._notify_trait(*v) def _notify_trait(self, name, old_value, new_value):