From 466a0e3864a2d762d4a9331391e912cd8b1d3f00 2014-02-03 21:36:41 From: Jonathan Frederic Date: 2014-02-03 21:36:41 Subject: [PATCH] Use closure to prevent traitlet callbacks from firing twice. --- diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index 9d4e901..81f61da 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -993,6 +993,8 @@ class TestBind(TestCase): # Change one of the values to make sure they stay in sync. a.value = 5 self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) def test_bind_different(self): """Verify two traitlets of different types can be bound together using bind.""" @@ -1014,6 +1016,8 @@ class TestBind(TestCase): # Change one of the values to make sure they stay in sync. a.value = 5 self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) def test_unbind(self): """Verify two binded traitlets can be unbinded.""" @@ -1032,3 +1036,40 @@ class TestBind(TestCase): # Change one of the values to make sure they stay in sync. a.value = 5 self.assertNotEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two binded traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + class B(HasTraits): + count = Int() + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + def a_callback(name, old, new): + callback_count.append('a') + a.on_trait_change(a_callback, 'value') + def b_callback(name, old, new): + callback_count.append('b') + b.on_trait_change(b_callback, 'count') + + # Connect the two classes. + c = bind((a, 'value'), (b, 'count')) + + # Make sure b's count was set to a's value once. + self.assertEqual(''.join(callback_count), 'b') + del callback_count[:] + + # Make sure a's value was set to b's count once. 
+ b.count = 5 + self.assertEqual(''.join(callback_count), 'ba') + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual(''.join(callback_count), 'ab') + del callback_count[:] diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index ab3f947..30561ab 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -202,13 +202,15 @@ class bind(object): if len(args) < 2: raise TypeError('At least two traitlets must be provided.') - self.objects = args + self.objects = {} + initial = getattr(args[0][0], args[0][1]) for obj,attr in args: - obj.on_trait_change(self._update, attr) + if getattr(obj, attr) != initial: + setattr(obj, attr, initial) - # Syncronize the traitlets initially. - initial = getattr(args[0][0], args[0][1]) - self._update(args[0][1], initial, initial) + callback = self._make_closure(obj,attr) + obj.on_trait_change(callback, attr) + self.objects[(obj,attr)] = callback @contextlib.contextmanager def _busy_updating(self): @@ -218,16 +220,23 @@ class bind(object): finally: self.updating = False - def _update(self, name, old, new): + def _make_closure(self, sending_obj, sending_attr): + def update(name, old, new): + self._update(sending_obj, sending_attr, new) + return update + + def _update(self, sending_obj, sending_attr, new): if self.updating: return with self._busy_updating(): - for obj,attr in self.objects: - setattr(obj, attr, new) + for obj,attr in self.objects.keys(): + if obj is not sending_obj or attr != sending_attr: + setattr(obj, attr, new) def unbind(self): - for obj,attr in self.objects: - obj.on_trait_change(self._update, attr, remove=True) + for key, callback in self.objects.items(): + (obj,attr) = key + obj.on_trait_change(callback, attr, remove=True) #----------------------------------------------------------------------------- # Base TraitType for all traits