##// END OF EJS Templates
Fix equals sign and clarify where the number of restart attempts comes from
Fix equals sign and clarify where the number of restart attempts comes from

File last commit:

r17673:1f9294c2
r18235:b839870a
Show More
test_traitlets.py
1308 lines | 35.4 KiB | text/x-python | PythonLexer
# encoding: utf-8
"""Tests for IPython.utils.traitlets."""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
#
# Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
# also under the terms of the Modified BSD License.
import pickle
import re
import sys
from unittest import TestCase
import nose.tools as nt
from nose import SkipTest
from IPython.utils.traitlets import (
HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
Undefined, Type, This, Instance, TCPAddress, List, Tuple,
ObjectName, DottedObjectName, CRegExp, link, directional_link,
EventfulList, EventfulDict
)
from IPython.utils import py3compat
from IPython.testing.decorators import skipif
#-----------------------------------------------------------------------------
# Helper classes for testing
#-----------------------------------------------------------------------------
class HasTraitsStub(HasTraits):
    """HasTraits subclass that records the last change notification.

    ``_notify_trait`` is overridden so that the ``(name, old, new)`` triple
    it receives is stored on the instance, where tests can inspect it.
    """

    def _notify_trait(self, name, old, new):
        # Keep the three arguments around for later assertions.
        self._notify_name, self._notify_old, self._notify_new = name, old, new
#-----------------------------------------------------------------------------
# Test classes
#-----------------------------------------------------------------------------
class TestTraitType(TestCase):
    """Checks of the base ``TraitType`` descriptor and its overridable hooks."""

    def test_get_undefined(self):
        """A plain ``TraitType`` reads back as ``Undefined`` before assignment."""
        class A(HasTraits):
            a = TraitType

        holder = A()
        self.assertEqual(holder.a, Undefined)

    def test_set(self):
        """Assignment stores the value and notifies with (name, old, new)."""
        class A(HasTraitsStub):
            a = TraitType

        holder = A()
        holder.a = 10
        self.assertEqual(holder.a, 10)
        # HasTraitsStub stored the arguments of the notification it received.
        self.assertEqual(holder._notify_name, 'a')
        self.assertEqual(holder._notify_old, Undefined)
        self.assertEqual(holder._notify_new, 10)

    def test_validate(self):
        """The value returned by ``validate`` is the value that gets stored."""
        class MyTT(TraitType):
            def validate(self, inst, value):
                return -1

        class A(HasTraitsStub):
            tt = MyTT

        holder = A()
        holder.tt = 10
        self.assertEqual(holder.tt, -1)

    def test_default_validate(self):
        """Default values are run through ``validate`` as well."""
        class MyIntTT(TraitType):
            def validate(self, obj, value):
                if isinstance(value, int):
                    return value
                self.error(obj, value)

        class A(HasTraits):
            tt = MyIntTT(10)

        self.assertEqual(A().tt, 10)

        # Defaults are validated when the HasTraits is instantiated
        class B(HasTraits):
            tt = MyIntTT('bad default')

        self.assertRaises(TraitError, B)

    def test_is_valid_for(self):
        """A trait defining ``is_valid_for`` that returns True accepts the value."""
        class MyTT(TraitType):
            def is_valid_for(self, value):
                return True

        class A(HasTraits):
            tt = MyTT

        holder = A()
        holder.tt = 10
        self.assertEqual(holder.tt, 10)

    def test_value_for(self):
        """A trait defining ``value_for`` stores what that hook returns."""
        class MyTT(TraitType):
            def value_for(self, value):
                return 20

        class A(HasTraits):
            tt = MyTT

        holder = A()
        holder.tt = 10
        self.assertEqual(holder.tt, 20)

    def test_info(self):
        """The base ``TraitType`` describes itself as 'any value'."""
        class A(HasTraits):
            tt = TraitType

        # Instantiate once, as the original test does, before querying the class.
        a = A()
        self.assertEqual(A.tt.info(), 'any value')

    def test_error(self):
        """``TraitType.error`` raises ``TraitError``."""
        class A(HasTraits):
            tt = TraitType

        holder = A()
        self.assertRaises(TraitError, A.tt.error, holder, 10)

    def test_dynamic_initializer(self):
        """``_x_default`` methods supply defaults lazily and can be overridden."""
        class A(HasTraits):
            x = Int(10)

            def _x_default(self):
                return 11

        class B(A):
            x = Int(20)

        class C(A):
            def _x_default(self):
                return 21

        def dyn_init_names(inst):
            # Names of the traits registered for dynamic initialisation.
            return list(inst._trait_dyn_inits.keys())

        a = A()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(dyn_init_names(a), ['x'])
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {'x': 11})

        b = B()
        self.assertEqual(b._trait_values, {'x': 20})
        self.assertEqual(dyn_init_names(a), ['x'])
        self.assertEqual(b.x, 20)

        c = C()
        self.assertEqual(c._trait_values, {})
        self.assertEqual(dyn_init_names(a), ['x'])
        self.assertEqual(c.x, 21)
        self.assertEqual(c._trait_values, {'x': 21})

        # Ensure that the base class remains unmolested when the _default
        # initializer gets overridden in a subclass.
        a = A()
        c = C()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(dyn_init_names(a), ['x'])
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {'x': 11})
class TestHasTraitsMeta(TestCase):
    """Checks of the ``MetaHasTraits`` metaclass and of ``This.this_class``."""

    def test_metaclass(self):
        """Int declared as a class, as an instance, or with a default all work."""
        self.assertEqual(type(HasTraits), MetaHasTraits)

        # Trait given as a bare class.
        class A(HasTraits):
            a = Int

        first = A()
        self.assertEqual(type(first.__class__), MetaHasTraits)
        self.assertEqual(first.a, 0)
        first.a = 10
        self.assertEqual(first.a, 10)

        # Trait given as an instance without arguments.
        class B(HasTraits):
            b = Int()

        second = B()
        self.assertEqual(second.b, 0)
        second.b = 10
        self.assertEqual(second.b, 10)

        # Trait given as an instance with an explicit default.
        class C(HasTraits):
            c = Int(30)

        third = C()
        self.assertEqual(third.c, 30)
        third.c = 10
        self.assertEqual(third.c, 10)

    def test_this_class(self):
        """``this_class`` is the class whose body declared the trait."""
        class A(HasTraits):
            t = This()
            tt = This()

        class B(A):
            tt = This()
            ttt = This()

        self.assertEqual(A.t.this_class, A)
        self.assertEqual(B.t.this_class, A)    # inherited, still owned by A
        self.assertEqual(B.tt.this_class, B)   # redeclared in B
        self.assertEqual(B.ttt.this_class, B)
class TestHasTraitsNotify(TestCase):
    """Checks of dynamic (``on_trait_change``) and static (``_x_changed``) notifiers."""

    def setUp(self):
        # Fresh logs for every test; notify1/notify2 append to them.
        self._notify1 = []
        self._notify2 = []

    def notify1(self, name, old, new):
        """Change handler that records its arguments in ``self._notify1``."""
        self._notify1.append((name, old, new))

    def notify2(self, name, old, new):
        """Change handler that records its arguments in ``self._notify2``."""
        self._notify2.append((name, old, new))

    def test_notify_all(self):
        """A handler registered without a name sees changes of every trait."""
        class A(HasTraits):
            a = Int
            b = Float

        a = A()
        a.on_trait_change(self.notify1)
        # Assigning the current value must not notify.
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        self.assertIn(('a', 0, 10), self._notify1)
        a.b = 10.0
        self.assertIn(('b', 0.0, 10.0), self._notify1)
        self.assertRaises(TraitError, setattr, a, 'a', 'bad string')
        self.assertRaises(TraitError, setattr, a, 'b', 'bad string')
        # After removal, further changes must leave the log empty.
        self._notify1 = []
        a.on_trait_change(self.notify1, remove=True)
        a.a = 20
        a.b = 20.0
        self.assertEqual(len(self._notify1), 0)

    def test_notify_one(self):
        """A handler registered for one trait name is called for that trait."""
        class A(HasTraits):
            a = Int
            b = Float

        a = A()
        a.on_trait_change(self.notify1, 'a')
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        self.assertIn(('a', 0, 10), self._notify1)
        self.assertRaises(TraitError, setattr, a, 'a', 'bad string')

    def test_subclass(self):
        """A subclass exposes both its own and its inherited traits."""
        class A(HasTraits):
            a = Int

        class B(A):
            b = Float

        b = B()
        self.assertEqual(b.a, 0)
        self.assertEqual(b.b, 0.0)
        b.a = 100
        b.b = 100.0
        self.assertEqual(b.a, 100)
        self.assertEqual(b.b, 100.0)

    def test_notify_subclass(self):
        """Per-name handlers work for inherited and subclass traits alike."""
        class A(HasTraits):
            a = Int

        class B(A):
            b = Float

        b = B()
        b.on_trait_change(self.notify1, 'a')
        b.on_trait_change(self.notify2, 'b')
        b.a = 0
        b.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        self.assertEqual(len(self._notify2), 0)
        b.a = 10
        b.b = 10.0
        self.assertIn(('a', 0, 10), self._notify1)
        self.assertIn(('b', 0.0, 10.0), self._notify2)

    def test_static_notify(self):
        """``_<name>_changed`` methods are called on the class and its subclass."""
        class A(HasTraits):
            a = Int
            # Class-level log: shared by every instance of A and of B below.
            _notify1 = []

            def _a_changed(self, name, old, new):
                self._notify1.append((name, old, new))

        a = A()
        a.a = 0
        # This is broken!!!
        # NOTE(review): the comment above is from the original author; what
        # exactly is considered broken is not stated - confirm.
        self.assertEqual(len(a._notify1), 0)
        a.a = 10
        self.assertIn(('a', 0, 10), a._notify1)

        class B(A):
            b = Float
            _notify2 = []

            def _b_changed(self, name, old, new):
                self._notify2.append((name, old, new))

        # B inherits A's list object, which already holds ('a', 0, 10) from
        # the instance above.  Empty it so the assertion below can only pass
        # if the subclass instance itself triggered the static notifier.
        del A._notify1[:]
        b = B()
        b.a = 10
        b.b = 10.0
        self.assertIn(('a', 0, 10), b._notify1)
        self.assertIn(('b', 0.0, 10.0), b._notify2)

    def test_notify_args(self):
        """Handlers may take zero, one, two or three positional arguments."""
        def callback0():
            self.cb = ()

        def callback1(name):
            self.cb = (name,)

        def callback2(name, new):
            self.cb = (name, new)

        def callback3(name, old, new):
            self.cb = (name, old, new)

        class A(HasTraits):
            a = Int

        a = A()
        a.on_trait_change(callback0, 'a')
        a.a = 10
        self.assertEqual(self.cb, ())
        a.on_trait_change(callback0, 'a', remove=True)

        a.on_trait_change(callback1, 'a')
        a.a = 100
        self.assertEqual(self.cb, ('a',))
        a.on_trait_change(callback1, 'a', remove=True)

        a.on_trait_change(callback2, 'a')
        a.a = 1000
        self.assertEqual(self.cb, ('a', 1000))
        a.on_trait_change(callback2, 'a', remove=True)

        a.on_trait_change(callback3, 'a')
        a.a = 10000
        self.assertEqual(self.cb, ('a', 1000, 10000))
        a.on_trait_change(callback3, 'a', remove=True)

        # Every handler was removed again, so none may be left registered.
        self.assertEqual(len(a._trait_notifiers['a']), 0)

    def test_notify_only_once(self):
        """Named, anonymous and static handlers are each called once per change."""
        class A(HasTraits):
            listen_to = ['a']

            a = Int(0)
            b = 0   # incremented by listener1

            def __init__(self, **kwargs):
                super(A, self).__init__(**kwargs)
                self.on_trait_change(self.listener1, ['a'])

            def listener1(self, name, old, new):
                self.b += 1

        class B(A):
            c = 0   # incremented by listener2
            d = 0   # incremented by the static _a_changed handler

            def __init__(self, **kwargs):
                super(B, self).__init__(**kwargs)
                self.on_trait_change(self.listener2)

            def listener2(self, name, old, new):
                self.c += 1

            def _a_changed(self, name, old, new):
                self.d += 1

        b = B()
        b.a += 1
        self.assertEqual(b.b, b.c)
        self.assertEqual(b.b, b.d)
        b.a += 1
        self.assertEqual(b.b, b.c)
        self.assertEqual(b.b, b.d)
class TestHasTraits(TestCase):
    """Checks of the introspection and construction API of ``HasTraits``."""

    def test_trait_names(self):
        """Trait names are available from the instance and from the class."""
        class A(HasTraits):
            i = Int
            f = Float

        inst = A()
        self.assertEqual(sorted(inst.trait_names()), ['f', 'i'])
        self.assertEqual(sorted(A.class_trait_names()), ['f', 'i'])

    def test_trait_metadata(self):
        """Keyword arguments given to a trait are retrievable as metadata."""
        class A(HasTraits):
            i = Int(config_key='MY_VALUE')

        inst = A()
        self.assertEqual(inst.trait_metadata('i', 'config_key'), 'MY_VALUE')

    def test_trait_metadata_default(self):
        """Missing metadata yields None, or the supplied fallback."""
        class A(HasTraits):
            i = Int()

        inst = A()
        self.assertEqual(inst.trait_metadata('i', 'config_key'), None)
        self.assertEqual(inst.trait_metadata('i', 'config_key', 'default'), 'default')

    def test_traits(self):
        """``traits()`` and ``class_traits()`` map names to trait objects."""
        class A(HasTraits):
            i = Int
            f = Float

        inst = A()
        expected = {'i': A.i, 'f': A.f}
        self.assertEqual(inst.traits(), expected)
        self.assertEqual(A.class_traits(), expected)

    def test_traits_metadata(self):
        """``traits()`` can filter on metadata values."""
        class A(HasTraits):
            i = Int(config_key='VALUE1', other_thing='VALUE2')
            f = Float(config_key='VALUE3', other_thing='VALUE2')
            j = Int(0)

        inst = A()
        everything = {'i': A.i, 'f': A.f, 'j': A.j}
        self.assertEqual(inst.traits(), everything)

        selected = inst.traits(config_key='VALUE1', other_thing='VALUE2')
        self.assertEqual(selected, {'i': A.i})

        # This passes, but it shouldn't because I am replicating a bug in
        # traits.
        selected = inst.traits(config_key=lambda v: True)
        self.assertEqual(selected, everything)

    def test_init(self):
        """Keyword arguments of the constructor set trait values."""
        class A(HasTraits):
            i = Int()
            x = Float()

        inst = A(i=1, x=10.0)
        self.assertEqual(inst.i, 1)
        self.assertEqual(inst.x, 10.0)

    def test_positional_args(self):
        """A subclass may define its own ``__init__`` with positional arguments."""
        class A(HasTraits):
            i = Int(0)

            def __init__(self, i):
                super(A, self).__init__()
                self.i = i

        inst = A(5)
        self.assertEqual(inst.i, 5)
        # should raise TypeError if no positional arg given
        self.assertRaises(TypeError, A)
#-----------------------------------------------------------------------------
# Tests for specific trait types
#-----------------------------------------------------------------------------
class TestType(TestCase):
    """Checks of the ``Type`` trait, whose values are classes."""

    def test_default(self):
        """Without arguments the value starts as None and accepts any class."""
        class B(object):
            pass

        class A(HasTraits):
            klass = Type

        holder = A()
        self.assertEqual(holder.klass, None)
        holder.klass = B
        self.assertEqual(holder.klass, B)
        self.assertRaises(TraitError, setattr, holder, 'klass', 10)

    def test_value(self):
        """``Type(B)`` defaults to B and rejects classes that are not B."""
        class B(object):
            pass

        class C(object):
            pass

        class A(HasTraits):
            klass = Type(B)

        holder = A()
        self.assertEqual(holder.klass, B)
        self.assertRaises(TraitError, setattr, holder, 'klass', C)
        self.assertRaises(TraitError, setattr, holder, 'klass', object)
        holder.klass = B

    def test_allow_none(self):
        """With ``allow_none=False`` None is rejected; a subclass is accepted."""
        class B(object):
            pass

        class C(B):
            pass

        class A(HasTraits):
            klass = Type(B, allow_none=False)

        holder = A()
        self.assertEqual(holder.klass, B)
        self.assertRaises(TraitError, setattr, holder, 'klass', None)
        holder.klass = C
        self.assertEqual(holder.klass, C)

    def test_validate_klass(self):
        """Unimportable class strings raise ImportError on instantiation."""
        for bad_path in ('no strings allowed', 'rub.adub.Duck'):
            class A(HasTraits):
                klass = Type(bad_path)

            self.assertRaises(ImportError, A)

    def test_validate_default(self):
        """Invalid defaults are reported when the owner is instantiated."""
        class B(object):
            pass

        class A(HasTraits):
            klass = Type('bad default', B)

        self.assertRaises(ImportError, A)

        class C(HasTraits):
            klass = Type(None, B, allow_none=False)

        self.assertRaises(TraitError, C)

    def test_str_klass(self):
        """A dotted-string default class still validates assigned classes."""
        class A(HasTraits):
            klass = Type('IPython.utils.ipstruct.Struct')

        from IPython.utils.ipstruct import Struct
        holder = A()
        holder.klass = Struct
        self.assertEqual(holder.klass, Struct)
        self.assertRaises(TraitError, setattr, holder, 'klass', 10)

    def test_set_str_klass(self):
        """A dotted string given as the value reads back as the class."""
        class A(HasTraits):
            klass = Type()

        holder = A(klass='IPython.utils.ipstruct.Struct')
        from IPython.utils.ipstruct import Struct
        self.assertEqual(holder.klass, Struct)
class TestInstance(TestCase):
    """Checks of the ``Instance`` trait, whose values are instances of a class."""

    def test_basic(self):
        """Instances of the class or a subclass are accepted; classes are not."""
        class Foo(object):
            pass

        class Bar(Foo):
            pass

        class Bah(object):
            pass

        class A(HasTraits):
            inst = Instance(Foo)

        holder = A()
        self.assertIsNone(holder.inst)
        holder.inst = Foo()
        self.assertIsInstance(holder.inst, Foo)
        holder.inst = Bar()
        self.assertIsInstance(holder.inst, Foo)
        # Classes themselves and unrelated instances are rejected.
        self.assertRaises(TraitError, setattr, holder, 'inst', Foo)
        self.assertRaises(TraitError, setattr, holder, 'inst', Bar)
        self.assertRaises(TraitError, setattr, holder, 'inst', Bah())

    def test_default_klass(self):
        """An ``Instance`` subclass can fix the class via its ``klass`` attribute."""
        class Foo(object):
            pass

        class Bar(Foo):
            pass

        class Bah(object):
            pass

        class FooInstance(Instance):
            klass = Foo

        class A(HasTraits):
            inst = FooInstance()

        holder = A()
        self.assertIsNone(holder.inst)
        holder.inst = Foo()
        self.assertIsInstance(holder.inst, Foo)
        holder.inst = Bar()
        self.assertIsInstance(holder.inst, Foo)
        self.assertRaises(TraitError, setattr, holder, 'inst', Foo)
        self.assertRaises(TraitError, setattr, holder, 'inst', Bar)
        self.assertRaises(TraitError, setattr, holder, 'inst', Bah())

    def test_unique_default_value(self):
        """Each owner gets its own default instance."""
        class Foo(object):
            pass

        class A(HasTraits):
            inst = Instance(Foo, (), {})

        first = A()
        second = A()
        self.assertIsNot(first.inst, second.inst)

    def test_args_kw(self):
        """``args``/``kw`` are used to build the default; without them it is None."""
        class Foo(object):
            def __init__(self, c):
                self.c = c

        class Bar(object):
            pass

        class Bah(object):
            def __init__(self, c, d):
                self.c = c
                self.d = d

        class A(HasTraits):
            inst = Instance(Foo, (10,))

        a = A()
        self.assertEqual(a.inst.c, 10)

        class B(HasTraits):
            inst = Instance(Bah, args=(10,), kw=dict(d=20))

        b = B()
        self.assertEqual(b.inst.c, 10)
        self.assertEqual(b.inst.d, 20)

        class C(HasTraits):
            inst = Instance(Foo)

        c = C()
        self.assertIsNone(c.inst)

    def test_bad_default(self):
        """No default plus ``allow_none=False`` fails at instantiation."""
        class Foo(object):
            pass

        class A(HasTraits):
            inst = Instance(Foo, allow_none=False)

        self.assertRaises(TraitError, A)

    def test_instance(self):
        """Passing an instance instead of a class to ``Instance`` is an error."""
        class Foo(object):
            pass

        def inner():
            # Defining the class is what is expected to raise.
            class A(HasTraits):
                inst = Instance(Foo())

        self.assertRaises(TraitError, inner)
class TestThis(TestCase):
    """Checks of the ``This`` trait, which refers to the declaring class."""

    def test_this_class(self):
        """``This`` given as a class starts at None and accepts a sibling instance."""
        class Foo(HasTraits):
            this = This

        first = Foo()
        self.assertEqual(first.this, None)
        second = Foo()
        first.this = second
        self.assertEqual(first.this, second)
        self.assertRaises(TraitError, setattr, first, 'this', 10)

    def test_this_inst(self):
        """``This()`` given as an instance accepts an instance of the owner class."""
        class Foo(HasTraits):
            this = This()

        owner = Foo()
        owner.this = Foo()
        self.assertIsInstance(owner.this, Foo)

    def test_subclass(self):
        """An inherited ``This`` accepts base and subclass instances both ways."""
        class Foo(HasTraits):
            t = This()

        class Bar(Foo):
            pass

        base_obj = Foo()
        sub_obj = Bar()
        base_obj.t = sub_obj
        sub_obj.t = base_obj
        self.assertEqual(base_obj.t, sub_obj)
        self.assertEqual(sub_obj.t, base_obj)

    def test_subclass_override(self):
        """A ``This`` redeclared in the subclass rejects base-class instances."""
        class Foo(HasTraits):
            t = This()

        class Bar(Foo):
            t = This()

        base_obj = Foo()
        sub_obj = Bar()
        base_obj.t = sub_obj
        self.assertEqual(base_obj.t, sub_obj)
        self.assertRaises(TraitError, setattr, sub_obj, 't', base_obj)
class TraitTestBase(TestCase):
    """A base testing class for basic trait types.

    Subclasses provide ``obj`` (an object with a trait named ``value``) and,
    optionally, ``_default_value``, ``_good_values`` and ``_bad_values``.
    Every check is skipped when the attribute it needs is absent, so the
    base class itself runs without doing anything.
    """

    def assign(self, value):
        """Assign ``value`` to the trait under test."""
        self.obj.value = value

    def coerce(self, value):
        """Return the value expected to be read back after assigning ``value``."""
        return value

    def test_good_values(self):
        """Every good value is accepted and reads back as ``coerce(value)``."""
        if hasattr(self, '_good_values'):
            for value in self._good_values:
                self.assign(value)
                self.assertEqual(self.obj.value, self.coerce(value))

    def test_bad_values(self):
        """Every bad value is rejected with ``TraitError``."""
        if hasattr(self, '_bad_values'):
            for value in self._bad_values:
                try:
                    self.assertRaises(TraitError, self.assign, value)
                except AssertionError:
                    # Re-raise with the offending value as the message.
                    # self.fail is used rather than a bare ``assert`` so the
                    # failure is still reported when Python runs with -O.
                    self.fail(value)

    def test_default_value(self):
        """The trait starts out with the declared default."""
        if hasattr(self, '_default_value'):
            self.assertEqual(self._default_value, self.obj.value)

    def test_allow_none(self):
        """With ``allow_none`` switched on, None moves from bad to accepted."""
        if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
                None in self._bad_values):
            trait = self.obj.traits()['value']
            try:
                trait.allow_none = True
                self._bad_values.remove(None)
                # skip coerce. Allow None casts None to None.
                self.assign(None)
                self.assertEqual(self.obj.value, None)
                self.test_good_values()
                self.test_bad_values()
            finally:
                # tear down: restore the trait flag and the value list
                trait.allow_none = False
                self._bad_values.append(None)

    def tearDown(self):
        # restore default value after tests, if set
        if hasattr(self, '_default_value'):
            self.obj.value = self._default_value
class AnyTrait(HasTraits):
    """Holder with a single ``Any`` trait named ``value``."""

    value = Any


class AnyTraitTest(TraitTestBase):
    """Runs the TraitTestBase checks against an ``Any`` trait."""

    obj = AnyTrait()

    _default_value = None
    _good_values = [
        10.0, 'ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
    ]
    _bad_values = []
class IntTrait(HasTraits):
    """Holder with a single ``Int`` trait named ``value`` (default 99)."""

    value = Int(99)


class TestInt(TraitTestBase):
    """Runs the TraitTestBase checks against an ``Int`` trait."""

    obj = IntTrait()

    _default_value = 99
    _good_values = [10, -10]
    _bad_values = [
        'ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
        10.1, -10.1,
        '10L', '-10L', '10.1', '-10.1',
        u'10L', u'-10L', u'10.1', u'-10.1',
        '10', '-10', u'10', u'-10',
    ]
    if not py3compat.PY3:
        # Python 2 only: long values are listed as bad for Int.
        _bad_values += [long(10), long(-10), 10*sys.maxint, -10*sys.maxint]
class LongTrait(HasTraits):
    """Holder with a single ``Long`` trait named ``value`` (default 99)."""

    value = Long(99 if py3compat.PY3 else long(99))


class TestLong(TraitTestBase):
    """Runs the TraitTestBase checks against a ``Long`` trait."""

    obj = LongTrait()

    _default_value = 99 if py3compat.PY3 else long(99)
    _good_values = [10, -10]
    _bad_values = [
        'ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
        10.1, -10.1,
        '10', '-10', '10L', '-10L', '10.1', '-10.1',
        u'10', u'-10', u'10L', u'-10L', u'10.1', u'-10.1',
    ]
    if not py3compat.PY3:
        # maxint undefined on py3, because int == long
        _good_values += [long(10), long(-10), 10*sys.maxint, -10*sys.maxint]
        _bad_values += [[long(10)], (long(10),)]

    @skipif(py3compat.PY3, "not relevant on py3")
    def test_cast_small(self):
        """An int assigned to a Long trait reads back with type ``long``."""
        self.obj.value = 10
        self.assertEqual(type(self.obj.value), long)
class IntegerTrait(HasTraits):
    """Holder with a single ``Integer`` trait named ``value`` (default 1)."""

    value = Integer(1)


class TestInteger(TestLong):
    """Runs the TestLong checks against an ``Integer`` trait.

    The good and bad value lists are inherited from ``TestLong``; values are
    expected to read back as ``int(n)``.
    """

    obj = IntegerTrait()
    _default_value = 1

    def coerce(self, n):
        """Return ``n`` converted with ``int``, the expected read-back value."""
        return int(n)

    @skipif(py3compat.PY3, "not relevant on py3")
    def test_cast_small(self):
        """Integer casts small longs to int"""
        # The skipif decorator already skips this test on Python 3, exactly
        # as for TestLong.test_cast_small, so no second check is needed here.
        self.obj.value = long(100)
        self.assertEqual(type(self.obj.value), int)
class FloatTrait(HasTraits):
    """Holder with a single ``Float`` trait named ``value`` (default 99.0)."""

    value = Float(99.0)


class TestFloat(TraitTestBase):
    """Runs the TraitTestBase checks against a ``Float`` trait."""

    obj = FloatTrait()

    _default_value = 99.0
    _good_values = [10, -10, 10.1, -10.1]
    _bad_values = [
        'ten', u'ten', [10], {'ten': 10}, (10,), None, 1j,
        '10', '-10', '10L', '-10L', '10.1', '-10.1',
        u'10', u'-10', u'10L', u'-10L', u'10.1', u'-10.1',
    ]
    if not py3compat.PY3:
        # Python 2 only: long values are listed as bad for Float.
        _bad_values += [long(10), long(-10)]
class ComplexTrait(HasTraits):
    """Holder with a single ``Complex`` trait named ``value``."""

    value = Complex(99.0-99.0j)


class TestComplex(TraitTestBase):
    """Runs the TraitTestBase checks against a ``Complex`` trait."""

    obj = ComplexTrait()

    _default_value = 99.0-99.0j
    _good_values = [
        10, -10, 10.1, -10.1,
        10j, 10+10j, 10-10j,
        10.1j, 10.1+10.1j, 10.1-10.1j,
    ]
    _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10}, (10,), None]
    if not py3compat.PY3:
        # Python 2 only: long values are listed as bad for Complex.
        _bad_values += [long(10), long(-10)]
class BytesTrait(HasTraits):
    """Holder with a single ``Bytes`` trait named ``value``."""

    value = Bytes(b'string')


class TestBytes(TraitTestBase):
    """Runs the TraitTestBase checks against a ``Bytes`` trait."""

    obj = BytesTrait()

    _default_value = b'string'
    _good_values = [
        b'10', b'-10', b'10L', b'-10L', b'10.1', b'-10.1', b'string',
    ]
    _bad_values = [
        10, -10, 10.1, -10.1, 1j,
        [10], ['ten'], {'ten': 10}, (10,), None, u'string',
    ]
    if not py3compat.PY3:
        # Python 2 only: long values are listed as bad for Bytes.
        _bad_values += [long(10), long(-10)]
class UnicodeTrait(HasTraits):
    """Holder with a single ``Unicode`` trait named ``value``."""

    value = Unicode(u'unicode')


class TestUnicode(TraitTestBase):
    """Runs the TraitTestBase checks against a ``Unicode`` trait."""

    obj = UnicodeTrait()

    _default_value = u'unicode'
    _good_values = [
        '10', '-10', '10L', '-10L', '10.1', '-10.1',
        '', u'', 'string', u'string', u"€",
    ]
    _bad_values = [
        10, -10, 10.1, -10.1, 1j,
        [10], ['ten'], [u'ten'], {'ten': 10}, (10,), None,
    ]
    if not py3compat.PY3:
        # Python 2 only: long values are listed as bad for Unicode.
        _bad_values += [long(10), long(-10)]
class ObjectNameTrait(HasTraits):
    """Holder with a single ``ObjectName`` trait named ``value``."""

    value = ObjectName("abc")


class TestObjectName(TraitTestBase):
    """Runs the TraitTestBase checks against an ``ObjectName`` trait."""

    obj = ObjectNameTrait()

    _default_value = "abc"
    _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
    _bad_values = [
        1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
        None, object(), object,
    ]
    if sys.version_info[0] >= 3:
        # þ=1 is valid in Python 3 (PEP 3131).
        _good_values.append(u"þ")
    else:
        _bad_values.append(u"þ")
class DottedObjectNameTrait(HasTraits):
    """Holder with a single ``DottedObjectName`` trait named ``value``."""

    value = DottedObjectName("a.b")


class TestDottedObjectName(TraitTestBase):
    """Runs the TraitTestBase checks against a ``DottedObjectName`` trait."""

    obj = DottedObjectNameTrait()

    _default_value = "a.b"
    _good_values = [
        "A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join",
    ]
    _bad_values = [
        1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None,
    ]
    # The non-ASCII name is listed as good on Python 3 and bad on Python 2.
    if sys.version_info[0] >= 3:
        _good_values.append(u"t.þ")
    else:
        _bad_values.append(u"t.þ")
class TCPAddressTrait(HasTraits):
    """Holder with a single ``TCPAddress`` trait named ``value``."""

    value = TCPAddress()


class TestTCPAddress(TraitTestBase):
    """Runs the TraitTestBase checks against a ``TCPAddress`` trait."""

    obj = TCPAddressTrait()

    _default_value = ('127.0.0.1', 0)
    _good_values = [
        ('localhost', 0),
        ('192.168.0.1', 1000),
        ('www.google.com', 80),
    ]
    _bad_values = [
        (0, 0),
        ('localhost', 10.0),
        ('localhost', -1),
        None,
    ]
class ListTrait(HasTraits):
    """Holder with a single ``List(Int)`` trait named ``value``."""

    value = List(Int)


class TestList(TraitTestBase):
    """Runs the TraitTestBase checks against a ``List(Int)`` trait."""

    obj = ListTrait()

    _default_value = []
    _good_values = [[], [1], list(range(10)), (1, 2)]
    _bad_values = [10, [1, 'a'], 'a']

    def coerce(self, value):
        """Expect non-None values to read back as a list."""
        return value if value is None else list(value)
class Foo(object):
    """Empty module-level class, referenced by dotted name below."""


class InstanceListTrait(HasTraits):
    """Holder with a ``List`` of ``Instance`` items whose class is given as a string."""

    value = List(Instance(__name__+'.Foo'))


class TestInstanceList(TraitTestBase):
    """Runs the TraitTestBase checks against a list of ``Foo`` instances."""

    obj = InstanceListTrait()

    _default_value = []
    _good_values = [[Foo(), Foo(), None], None]
    _bad_values = [['1', 2, ], '1', [Foo]]

    def test_klass(self):
        """Test that the instance klass is properly assigned."""
        element_trait = self.obj.traits()['value']._trait
        self.assertIs(element_trait.klass, Foo)
class LenListTrait(HasTraits):
    """Holder with a ``List(Int)`` trait declared with ``minlen=1, maxlen=2``."""

    value = List(Int, [0], minlen=1, maxlen=2)


class TestLenList(TraitTestBase):
    """Runs the TraitTestBase checks against a length-bounded list trait."""

    obj = LenListTrait()

    _default_value = [0]
    _good_values = [[1], [1, 2], (1, 2)]
    # The empty list and a three-element list are listed as bad.
    _bad_values = [10, [1, 'a'], 'a', [], list(range(3))]

    def coerce(self, value):
        """Expect non-None values to read back as a list."""
        return value if value is None else list(value)
class TupleTrait(HasTraits):
    """Holder with a one-element ``Tuple`` trait whose item is ``Int(allow_none=True)``."""

    value = Tuple(Int(allow_none=True))


class TestTupleTrait(TraitTestBase):
    """Runs the TraitTestBase checks against a typed one-element tuple trait."""

    obj = TupleTrait()

    _default_value = None
    _good_values = [(1,), None, (0,), [1], (None,)]
    # The original list spelled one entry ``('a')``, which is just the string
    # 'a'.  The string is kept, and the one-element tuple that was evidently
    # intended (right length, wrong element type) is added next to it.
    _bad_values = [10, (1, 2), 'a', ('a',), ()]

    def coerce(self, value):
        """Expect non-None values to read back as a tuple."""
        if value is not None:
            value = tuple(value)
        return value

    def test_invalid_args(self):
        """``Tuple`` raises TypeError for a non-trait, non-sequence argument."""
        self.assertRaises(TypeError, Tuple, 5)
        self.assertRaises(TypeError, Tuple, default_value='hello')
        # A valid combination: constructing it must not raise.
        Tuple(Int, CBytes, default_value=(1, 5))
class LooseTupleTrait(HasTraits):
    """Holder with an untyped ``Tuple`` trait whose default is ``(1, 2, 3)``."""

    value = Tuple((1, 2, 3))


class TestLooseTupleTrait(TraitTestBase):
    """Runs the TraitTestBase checks against an untyped tuple trait."""

    obj = LooseTupleTrait()

    _default_value = (1, 2, 3)
    _good_values = [
        (1,), None, [1], (0,),
        tuple(range(5)), tuple('hello'), ('a', 5), (),
    ]
    _bad_values = [10, 'hello', {}]

    def coerce(self, value):
        """Expect non-None values to read back as a tuple."""
        return value if value is None else tuple(value)

    def test_invalid_args(self):
        """``Tuple`` raises TypeError for a non-trait, non-sequence argument."""
        self.assertRaises(TypeError, Tuple, 5)
        self.assertRaises(TypeError, Tuple, default_value='hello')
        # A valid combination; if it raised, this test would error.
        t = Tuple(Int, CBytes, default_value=(1, 5))
class MultiTupleTrait(HasTraits):
    """Holder with a two-element ``Tuple(Int, Bytes)`` trait named ``value``."""

    value = Tuple(Int, Bytes, default_value=[99, b'bottles'])


class TestMultiTuple(TraitTestBase):
    """Runs the TraitTestBase checks against a typed two-element tuple trait."""

    obj = MultiTupleTrait()

    _default_value = (99, b'bottles')
    _good_values = [(1, b'a'), (2, b'b')]
    _bad_values = (
        (),
        10,
        b'a',
        (1, b'a', 3),
        (b'a', 1),
        (1, u'a'),
    )
class CRegExpTrait(HasTraits):
    """Holder with a single ``CRegExp`` trait named ``value``."""

    value = CRegExp(r'')


class TestCRegExp(TraitTestBase):
    """Runs the TraitTestBase checks against a ``CRegExp`` trait."""

    obj = CRegExpTrait()

    _default_value = re.compile(r'')
    _good_values = [r'\d+', re.compile(r'\d+')]
    _bad_values = ['(', None, ()]

    def coerce(self, value):
        """Expect the value to read back as ``re.compile(value)``."""
        return re.compile(value)
class DictTrait(HasTraits):
    """Holder with a single ``Dict`` trait named ``value``."""

    value = Dict()


def test_dict_assignment():
    """An assigned dict is stored by reference, so later mutation is visible."""
    source = dict()
    holder = DictTrait()
    holder.value = source
    # Mutate after assignment; the trait value must reflect the change.
    source['a'] = 5
    nt.assert_equal(source, holder.value)
    nt.assert_true(holder.value is source)
class TestLink(TestCase):
    """Checks of ``link``, which keeps two traits in sync in both directions."""

    def test_connect_same(self):
        """Verify two traitlets of the same type can be linked together using link."""
        # One class with an Int trait, instantiated twice.
        class A(HasTraits):
            value = Int()

        left = A(value=9)
        right = A(value=8)

        # Connect the two instances.
        binding = link((left, 'value'), (right, 'value'))

        # The values agree as soon as the link exists.
        self.assertEqual(left.value, right.value)

        # A change on either side shows up on the other.
        left.value = 5
        self.assertEqual(left.value, right.value)
        right.value = 6
        self.assertEqual(left.value, right.value)

    def test_link_different(self):
        """Verify two traitlets of different types can be linked together using link."""
        # Two classes whose Int traits have different names.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        left = A(value=9)
        right = B(count=8)

        # Connect the two instances.
        binding = link((left, 'value'), (right, 'count'))

        # The values agree as soon as the link exists.
        self.assertEqual(left.value, right.count)

        # A change on either side shows up on the other.
        left.value = 5
        self.assertEqual(left.value, right.count)
        right.count = 4
        self.assertEqual(left.value, right.count)

    def test_unlink(self):
        """Verify two linked traitlets can be unlinked."""
        class A(HasTraits):
            value = Int()

        left = A(value=9)
        right = A(value=8)

        # Connect, change once while linked, then disconnect.
        binding = link((left, 'value'), (right, 'value'))
        left.value = 4
        binding.unlink()

        # A later change must no longer propagate.
        left.value = 5
        self.assertNotEqual(left.value, right.value)

    def test_callbacks(self):
        """Verify two linked traitlets have their callbacks called once."""
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        left = A(value=9)
        right = B(count=8)

        # Each callback appends one letter, so the log shows order and count.
        callback_count = []

        def a_callback(name, old, new):
            callback_count.append('a')

        def b_callback(name, old, new):
            callback_count.append('b')

        left.on_trait_change(a_callback, 'value')
        right.on_trait_change(b_callback, 'count')

        # Connect the two instances.
        binding = link((left, 'value'), (right, 'count'))

        # Linking set right.count from left.value exactly once.
        self.assertEqual(''.join(callback_count), 'b')
        del callback_count[:]

        # Changing right.count fires b, then a, once each.
        right.count = 5
        self.assertEqual(''.join(callback_count), 'ba')
        del callback_count[:]

        # Changing left.value fires a, then b, once each.
        left.value = 4
        self.assertEqual(''.join(callback_count), 'ab')
        del callback_count[:]
class TestDirectionalLink(TestCase):
    """Checks of ``directional_link``, which syncs a source trait to a target only."""

    def test_connect_same(self):
        """Verify two traitlets of the same type can be linked together using directional_link."""
        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = directional_link((a, 'value'), (b, 'value'))

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.value)

        # Change the value of the source and check that it synchronizes the target.
        a.value = 5
        self.assertEqual(b.value, 5)
        # Change the value of the target and check that it has no impact on the source.
        b.value = 6
        self.assertEqual(a.value, 5)

    def test_link_different(self):
        """Verify two traitlets of different types can be linked together using directional_link."""
        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        class B(HasTraits):
            count = Int()

        a = A(value=9)
        b = B(count=8)

        # Connect the two classes.
        c = directional_link((a, 'value'), (b, 'count'))

        # Make sure the values are the same at the point of linking.
        self.assertEqual(a.value, b.count)

        # Change the value of the source and check that it synchronizes the target.
        a.value = 5
        self.assertEqual(b.count, 5)
        # Change the value of the target and check that it has no impact on
        # the source.  The linked trait on B is ``count``; assigning to
        # ``b.value`` would only create an unrelated plain attribute and the
        # check below would pass without exercising the link.
        b.count = 6
        self.assertEqual(a.value, 5)

    def test_unlink(self):
        """Verify two linked traitlets can be unlinked."""
        # Create two simple classes with Int traitlets.
        class A(HasTraits):
            value = Int()

        a = A(value=9)
        b = A(value=8)

        # Connect the two classes.
        c = directional_link((a, 'value'), (b, 'value'))
        a.value = 4
        c.unlink()

        # Change one of the values to make sure they don't stay in sync.
        a.value = 5
        self.assertNotEqual(a.value, b.value)
class Pickleable(HasTraits):
    """HasTraits class with a dynamic default and a static change handler.

    ``i`` gets its default from ``_i_default``; every change of ``i`` is
    copied into ``j`` by ``_i_changed``.
    """

    i = Int()
    j = Int()

    def _i_default(self):
        return 1

    def _i_changed(self, name, old, new):
        self.j = new


def test_pickle_hastraits():
    """A HasTraits object survives a pickle round trip under every protocol."""
    def check_roundtrip(original):
        # Dump and reload with each protocol; both traits must match.
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            clone = pickle.loads(pickle.dumps(original, protocol))
            nt.assert_equal(clone.i, original.i)
            nt.assert_equal(clone.j, original.j)

    subject = Pickleable()
    check_roundtrip(subject)

    # Repeat after an explicit assignment to i.
    subject.i = 5
    check_roundtrip(subject)
class TestEventful(TestCase):
    """Checks that EventfulList and EventfulDict report their mutations."""

    def test_list(self):
        """Does the EventfulList work?"""
        event_cache = []

        class A(HasTraits):
            x = EventfulList(list('abc'))

        holder = A()
        # One callback per event kind; each records the kind it stands for.
        holder.x.on_events(
            lambda i, x: event_cache.append('insert'),
            lambda i, x: event_cache.append('set'),
            lambda i: event_cache.append('del'),
            lambda: event_cache.append('reverse'),
            lambda *p, **k: event_cache.append('sort'))

        holder.x.remove('c')        # ab
        holder.x.insert(0, 'z')     # zab
        del holder.x[1]             # zb
        holder.x.reverse()          # bz
        holder.x[1] = 'o'           # bo
        holder.x.append('a')        # boa
        holder.x.sort()             # abo

        # Were the correct events captured?
        expected_events = ['del', 'insert', 'del', 'reverse', 'set', 'set', 'sort']
        self.assertEqual(event_cache, expected_events)

        # Is the output correct?
        self.assertEqual(holder.x, list('abo'))

    def test_dict(self):
        """Does the EventfulDict work?"""
        event_cache = []

        class A(HasTraits):
            x = EventfulDict({key: key for key in 'abc'})

        holder = A()
        # One callback per event kind; each records the kind it stands for.
        holder.x.on_events(
            lambda k, v: event_cache.append('add'),
            lambda k, v: event_cache.append('set'),
            lambda k: event_cache.append('del'))

        del holder.x['c']           # ab
        holder.x['z'] = 1           # abz
        holder.x['z'] = 'z'         # abz
        holder.x.pop('a')           # bz

        # Were the correct events captured?
        self.assertEqual(event_cache, ['del', 'add', 'set', 'del'])

        # Is the output correct?
        self.assertEqual(holder.x, {key: key for key in 'bz'})