diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index cdb0cd3..81f61da 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -32,7 +32,7 @@ from IPython.utils.traitlets import ( HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict, Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError, Undefined, Type, This, Instance, TCPAddress, List, Tuple, - ObjectName, DottedObjectName, CRegExp + ObjectName, DottedObjectName, CRegExp, bind ) from IPython.utils import py3compat from IPython.testing.decorators import skipif @@ -973,3 +973,103 @@ def test_dict_assignment(): d['a'] = 5 nt.assert_equal(d, c.value) nt.assert_true(c.value is d) + +class TestBind(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be bound together using bind.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = bind((a, 'value'), (b, 'value')) + + # Make sure the values are the same at the point of binding. + self.assertEqual(a.value, b.value) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) + + def test_bind_different(self): + """Verify two traitlets of different types can be bound together using bind.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + class B(HasTraits): + count = Int() + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = bind((a, 'value'), (b, 'count')) + + # Make sure the values are the same at the point of binding. + self.assertEqual(a.value, b.count) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) + + def test_unbind(self): + """Verify two binded traitlets can be unbinded.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = bind((a, 'value'), (b, 'value')) + a.value = 4 + c.unbind() + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two binded traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + class B(HasTraits): + count = Int() + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + def a_callback(name, old, new): + callback_count.append('a') + a.on_trait_change(a_callback, 'value') + def b_callback(name, old, new): + callback_count.append('b') + b.on_trait_change(b_callback, 'count') + + # Conenct the two classes. + c = bind((a, 'value'), (b, 'count')) + + # Make sure b's count was set to a's value once. + self.assertEqual(''.join(callback_count), 'b') + del callback_count[:] + + # Make sure a's value was set to b's count once. + b.count = 5 + self.assertEqual(''.join(callback_count), 'ba') + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual(''.join(callback_count), 'ab') + del callback_count[:] diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index 31fb5cf..30561ab 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -52,7 +52,7 @@ Authors: # Imports #----------------------------------------------------------------------------- - +import contextlib import inspect import re import sys @@ -67,6 +67,7 @@ except: from .importstring import import_item from IPython.utils import py3compat from IPython.utils.py3compat import iteritems +from IPython.testing.skipdoctest import skip_doctest SequenceTypes = (list, tuple, set, frozenset) @@ -182,6 +183,60 @@ def getmembers(object, predicate=None): results.sort() return results +@skip_doctest +class bind(object): + """Bind traits from different objects together so they remain in sync. + + Parameters + ---------- + obj : pairs of objects/attributes + + Examples + -------- + + >>> c = bind((obj1, 'value'), (obj2, 'value'), (obj3, 'value')) + >>> obj1.value = 5 # updates other objects as well + """ + updating = False + def __init__(self, *args): + if len(args) < 2: + raise TypeError('At least two traitlets must be provided.') + + self.objects = {} + initial = getattr(args[0][0], args[0][1]) + for obj,attr in args: + if getattr(obj, attr) != initial: + setattr(obj, attr, initial) + + callback = self._make_closure(obj,attr) + obj.on_trait_change(callback, attr) + self.objects[(obj,attr)] = callback + + @contextlib.contextmanager + def _busy_updating(self): + self.updating = True + try: + yield + finally: + self.updating = False + + def _make_closure(self, sending_obj, sending_attr): + def update(name, old, new): + self._update(sending_obj, sending_attr, new) + return update + + def _update(self, sending_obj, sending_attr, new): + if self.updating: + return + with self._busy_updating(): + for obj,attr in self.objects.keys(): + if obj is not sending_obj or attr != sending_attr: + setattr(obj, attr, new) + + def unbind(self): + for key, callback in self.objects.items(): + (obj,attr) = key + obj.on_trait_change(callback, attr, remove=True) #----------------------------------------------------------------------------- # Base TraitType for all traits