##// END OF EJS Templates
Added support for Type and Instance traitlets with testing.
Brian Granger -
Show More
@@ -30,7 +30,8 b' from unittest import TestCase'
30
30
31 from IPython.utils.traitlets import (
31 from IPython.utils.traitlets import (
32 HasTraitlets, MetaHasTraitlets, TraitletType, Any,
32 HasTraitlets, MetaHasTraitlets, TraitletType, Any,
33 Int, Long, Float, Complex, Str, Unicode, Bool, TraitletError
33 Int, Long, Float, Complex, Str, Unicode, Bool, TraitletError,
34 Undefined, Type, Instance
34 )
35 )
35
36
36
37
@@ -65,15 +66,15 b' class TestTraitletType(TestCase):'
65 self.tt.name = self.name
66 self.tt.name = self.name
66 self.hast = HasTraitletsStub()
67 self.hast = HasTraitletsStub()
67
68
68 def test_get(self):
69 def test_get_undefined(self):
69 value = self.tt.__get__(self.hast)
70 value = self.tt.__get__(self.hast)
70 self.assertEquals(value, None)
71 self.assertEquals(value, Undefined)
71
72
72 def test_set(self):
73 def test_set(self):
73 self.tt.__set__(self.hast, 10)
74 self.tt.__set__(self.hast, 10)
74 self.assertEquals(self.hast._traitlet_values[self.name],10)
75 self.assertEquals(self.hast._traitlet_values[self.name],10)
75 self.assertEquals(self.hast._notify_name,self.name)
76 self.assertEquals(self.hast._notify_name,self.name)
76 self.assertEquals(self.hast._notify_old,None)
77 self.assertEquals(self.hast._notify_old,Undefined)
77 self.assertEquals(self.hast._notify_new,10)
78 self.assertEquals(self.hast._notify_new,10)
78
79
79 def test_validate(self):
80 def test_validate(self):
@@ -85,6 +86,20 b' class TestTraitletType(TestCase):'
85 tt.__set__(self.hast, 10)
86 tt.__set__(self.hast, 10)
86 self.assertEquals(tt.__get__(self.hast),-1)
87 self.assertEquals(tt.__get__(self.hast),-1)
87
88
89 def test_default_validate(self):
90 class MyIntTT(TraitletType):
91 def validate(self, obj, value):
92 if isinstance(value, int):
93 return value
94 self.error(obj, value)
95 tt = MyIntTT(10)
96 tt.name = 'a'
97 self.assertEquals(tt.__get__(self.hast), 10)
98 tt = MyIntTT('bad default')
99 tt.name = 'b' # different name from 'a' as we want an unset dv
100 self.assertRaises(TraitletError, tt.__get__, self.hast)
101
102
88 def test_is_valid_for(self):
103 def test_is_valid_for(self):
89 class MyTT(TraitletType):
104 class MyTT(TraitletType):
90 def is_valid_for(self, value):
105 def is_valid_for(self, value):
@@ -311,7 +326,113 b' class TestAddTraitlet(TestCase):'
311 # Tests for specific traitlet types
326 # Tests for specific traitlet types
312 #-----------------------------------------------------------------------------
327 #-----------------------------------------------------------------------------
313
328
329
330 class TestType(TestCase):
331
332 def test_default(self):
333
334 class B(object): pass
335 class A(HasTraitlets):
336 klass = Type
337
338 a = A()
339 self.assertEquals(a.klass, None)
340 a.klass = B
341 self.assertEquals(a.klass, B)
342 self.assertRaises(TraitletError, setattr, a, 'klass', 10)
343
344 def test_value(self):
345
346 class B(object): pass
347 class C(object): pass
348 class A(HasTraitlets):
349 klass = Type(B)
350
351 a = A()
352 self.assertEquals(a.klass, B)
353 self.assertRaises(TraitletError, setattr, a, 'klass', C)
354 self.assertRaises(TraitletError, setattr, a, 'klass', object)
355 a.klass = B
356
357 def test_allow_none(self):
358
359 class B(object): pass
360 class C(B): pass
361 class A(HasTraitlets):
362 klass = Type(B, allow_none=False)
363
364 a = A()
365 self.assertEquals(a.klass, B)
366 self.assertRaises(TraitletError, setattr, a, 'klass', None)
367 a.klass = C
368 self.assertEquals(a.klass, C)
369
370
371 class TestInstance(TestCase):
372
373 def test_basic(self):
374 class Foo(object): pass
375 class Bar(Foo): pass
376 class Bah(object): pass
377
378 class A(HasTraitlets):
379 inst = Instance(Foo)
380
381 a = A()
382 self.assert_(isinstance(a.inst, Foo))
383 a.inst = Foo()
384 self.assert_(isinstance(a.inst, Foo))
385 a.inst = Bar()
386 self.assert_(isinstance(a.inst, Foo))
387 self.assertRaises(TraitletError, setattr, a, 'inst', Foo)
388 self.assertRaises(TraitletError, setattr, a, 'inst', Bar)
389 self.assertRaises(TraitletError, setattr, a, 'inst', Bah())
390
391 def test_unique_default_value(self):
392 class Foo(object): pass
393 class A(HasTraitlets):
394 inst = Instance(Foo)
395
396 a = A()
397 b = A()
398 self.assert_(a.inst is not b.inst)
399
400 def test_args_kw(self):
401 class Foo(object):
402 def __init__(self, c): self.c = c
403
404 class A(HasTraitlets):
405 inst = Instance(Foo, args=(10,))
406
407 a = A()
408 self.assertEquals(a.inst.c, 10)
409
410 class Bar(object):
411 def __init__(self, c, d):
412 self.c = c; self.d = d
413
414 class B(HasTraitlets):
415 inst = Instance(Bar, args=(10,),kw=dict(d=20))
416 b = B()
417 self.assertEquals(b.inst.c, 10)
418 self.assertEquals(b.inst.d, 20)
419
420 def test_instance(self):
421 # Does passing an instance yield a default value of None?
422 class Foo(object): pass
423
424 class A(HasTraitlets):
425 inst = Instance(Foo())
426 a = A()
427 self.assertEquals(a.inst, None)
428
429 class B(HasTraitlets):
430 inst = Instance(Foo(), allow_none=False)
431 b = B()
432 self.assertRaises(TraitletError, getattr, b, 'inst')
433
314 class TraitletTestBase(TestCase):
434 class TraitletTestBase(TestCase):
435 """A best testing class for basic traitlet types."""
315
436
316 def assign(self, value):
437 def assign(self, value):
317 self.obj.value = value
438 self.obj.value = value
@@ -42,9 +42,13 b' Authors:'
42 # Imports
42 # Imports
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45
45 import inspect
46 import inspect
47 import sys
46 import types
48 import types
47 from types import InstanceType
49 from types import InstanceType, ClassType
50
51 ClassTypes = (ClassType, type)
48
52
49 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
50 # Basic classes
54 # Basic classes
@@ -124,6 +128,12 b' def parse_notifier_name(name):'
124 return name
128 return name
125
129
126
130
131 def get_module_name ( level = 2 ):
132 """ Returns the name of the module that the caller's caller is located in.
133 """
134 return sys._getframe( level ).f_globals.get( '__name__', '__main__' )
135
136
127 #-----------------------------------------------------------------------------
137 #-----------------------------------------------------------------------------
128 # Base TraitletType for all traitlets
138 # Base TraitletType for all traitlets
129 #-----------------------------------------------------------------------------
139 #-----------------------------------------------------------------------------
@@ -132,30 +142,55 b' def parse_notifier_name(name):'
132 class TraitletType(object):
142 class TraitletType(object):
133
143
134 metadata = {}
144 metadata = {}
135 default_value = None
145 default_value = Undefined
136 info_text = 'any value'
146 info_text = 'any value'
137
147
138 def __init__(self, default_value=NoDefaultSpecified, **metadata):
148 def __init__(self, default_value=NoDefaultSpecified, **metadata):
149 """Create a TraitletType.
150 """
139 if default_value is not NoDefaultSpecified:
151 if default_value is not NoDefaultSpecified:
140 self.default_value = default_value
152 self.default_value = default_value
141 self.metadata.update(metadata)
153 self.metadata.update(metadata)
154 self.init()
142
155
143 def __get__(self, inst, cls=None):
156 def init(self):
144 if inst is None:
157 pass
158
159 def get_default_value(self):
160 """Create a new instance of the default value."""
161 dv = self.default_value
162 return dv
163
164 def __get__(self, obj, cls=None):
165 """Get the value of the traitlet by self.name for the instance.
166
167 The creation of default values is deferred until this is called the
168 first time. This is done so instances of the parent HasTraitlets
169 will have their own default value instances.
170 """
171 if obj is None:
145 return self
172 return self
146 else:
173 else:
147 return inst._traitlet_values.get(self.name, self.default_value)
174 if not obj._traitlet_values.has_key(self.name):
175 dv = self.get_default_value()
176 self.__set__(obj, dv, first=True)
177 return dv
178 else:
179 return obj._traitlet_values[self.name]
148
180
149 def __set__(self, inst, value):
181 def __set__(self, obj, value, first=False):
150 new_value = self._validate(inst, value)
182 new_value = self._validate(obj, value)
151 old_value = self.__get__(inst)
183 if not first:
184 old_value = self.__get__(obj)
152 if old_value != new_value:
185 if old_value != new_value:
153 inst._traitlet_values[self.name] = new_value
186 obj._traitlet_values[self.name] = new_value
154 inst._notify(self.name, old_value, new_value)
187 obj._notify(self.name, old_value, new_value)
188 else:
189 obj._traitlet_values[self.name] = new_value
155
190
156 def _validate(self, inst, value):
191 def _validate(self, obj, value):
157 if hasattr(self, 'validate'):
192 if hasattr(self, 'validate'):
158 return self.validate(inst, value)
193 return self.validate(obj, value)
159 elif hasattr(self, 'is_valid_for'):
194 elif hasattr(self, 'is_valid_for'):
160 valid = self.is_valid_for(value)
195 valid = self.is_valid_for(value)
161 if valid:
196 if valid:
@@ -333,10 +368,291 b' class HasTraitlets(object):'
333 inst.name = name
368 inst.name = name
334 setattr(self.__class__, name, inst)
369 setattr(self.__class__, name, inst)
335
370
371
336 #-----------------------------------------------------------------------------
372 #-----------------------------------------------------------------------------
337 # Actual TraitletTypes implementations/subclasses
373 # Actual TraitletTypes implementations/subclasses
338 #-----------------------------------------------------------------------------
374 #-----------------------------------------------------------------------------
339
375
376 #-----------------------------------------------------------------------------
377 # TraitletTypes subclasses for handling classes and instances of classes
378 #-----------------------------------------------------------------------------
379
380
381 class BaseClassResolver(TraitletType):
382 """Mixin class for traitlets that need to resolve classes by strings.
383
384 This class provides is a mixin that provides its subclasses with the
385 ability to resolve classes by specifying a string name (for example,
386 'foo.bar.MyClass'). An actual class can also be resolved.
387
388 Any subclass must define instances with 'klass' and 'module' attributes
389 that contain the string name of the class (or actual class object) and
390 the module name that contained the original trait definition (used for
391 resolving local class names (e.g. 'LocalClass')).
392 """
393
394 def resolve_class(self, obj, value):
395 klass = self.validate_class(self.find_class(self.klass))
396 if klass is None:
397 self.validate_failed(obj, value)
398
399 self.klass = klass
400
401 def validate_class(self, klass):
402 return klass
403
404 def find_class(self, klass):
405 module = self.module
406 col = klass.rfind('.')
407 if col >= 0:
408 module = klass[ : col ]
409 klass = klass[ col + 1: ]
410
411 theClass = getattr(sys.modules.get(module), klass, None)
412 if (theClass is None) and (col >= 0):
413 try:
414 mod = __import__(module)
415 for component in module.split( '.' )[1:]:
416 mod = getattr(mod, component)
417
418 theClass = getattr(mod, klass, None)
419 except:
420 pass
421
422 return theClass
423
424 def validate_failed (self, obj, value):
425 kind = type(value)
426 if kind is InstanceType:
427 msg = 'class %s' % value.__class__.__name__
428 else:
429 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
430
431 self.error(obj, msg)
432
433
434 class Type(BaseClassResolver):
435 """A traitlet whose value must be a subclass of a specified class."""
436
437 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
438 """Construct a Type traitlet
439
440 A Type traitlet specifies that its values must be subclasses of
441 a particular class.
442
443 Parameters
444 ----------
445 default_value : class or None
446 The default value must be a subclass of klass.
447 klass : class, str, None
448 Values of this traitlet must be a subclass of klass. The klass
449 may be specified in a string like: 'foo.bar.MyClass'.
450 allow_none : boolean
451 Indicates whether None is allowed as an assignable value. Even if
452 ``False``, the default value may be ``None``.
453 """
454 if default_value is None:
455 if klass is None:
456 klass = object
457 elif klass is None:
458 klass = default_value
459
460 if isinstance(klass, basestring):
461 self.validate = self.resolve
462 elif not isinstance(klass, ClassTypes):
463 raise TraitletError("A Type traitlet must specify a class.")
464
465 self.klass = klass
466 self._allow_none = allow_none
467 self.module = get_module_name()
468
469 super(Type, self).__init__(default_value, **metadata)
470
471 def validate(self, obj, value):
472 """Validates that the value is a valid object instance."""
473 try:
474 if issubclass(value, self.klass):
475 return value
476 except:
477 if (value is None) and (self._allow_none):
478 return value
479
480 self.error(obj, value)
481
482 def resolve(self, obj, name, value):
483 """ Resolves a class originally specified as a string into an actual
484 class, then resets the trait so that future calls will be handled by
485 the normal validate method.
486 """
487 if isinstance(self.klass, basestring):
488 self.resolve_class(obj, value)
489 del self.validate
490
491 return self.validate(obj, value)
492
493 def info(self):
494 """ Returns a description of the trait."""
495 klass = self.klass
496 if not isinstance(klass, basestring):
497 klass = klass.__name__
498
499 result = 'a subclass of ' + klass
500
501 if self._allow_none:
502 return result + ' or None'
503
504 return result
505
506 def get_default_value(self):
507 """ Returns a tuple of the form: ( default_value_type, default_value )
508 which describes the default value for this trait.
509 """
510 if not isinstance(self.default_value, basestring):
511 return super(Type, self).get_default_value()
512
513 dv = self.resolve_default_value()
514 dvt = type(dv)
515 return (dvt, dv)
516
517 def resolve_default_value(self):
518 """ Resolves a class name into a class so that it can be used to
519 return the class as the default value of the trait.
520 """
521 if isinstance(self.klass, basestring):
522 try:
523 self.resolve_class(None, None)
524 del self.validate
525 except:
526 raise TraitletError('Could not resolve %s into a valid class' %
527 self.klass )
528
529 return self.klass
530
531
532 class DefaultValueGenerator(object):
533 """A class for generating new default value instances."""
534
535 def __init__(self, klass, *args, **kw):
536 self.klass = klass
537 self.args = args
538 self.kw = kw
539
540
541 class Instance(BaseClassResolver):
542 """A trait whose value must be an instance of a specified class.
543
544 The value can also be an instance of a subclass of the specified class.
545 """
546
547 def __init__(self, klass=None, args=None, kw=None, allow_none=True,
548 module = None, **metadata ):
549 """Construct an Instance traitlet.
550
551 Parameters
552 ----------
553 klass : class or instance
554 The object that forms the basis for the traitlet. If an instance
555 values must have isinstance(value, type(instance)).
556 args : tuple
557 Positional arguments for generating the default value.
558 kw : dict
559 Keyword arguments for generating the default value.
560 allow_none : bool
561 Indicates whether None is allowed as a value.
562
563 Default Value
564 -------------
565 If klass is an instance, default value is None. If klass is a class
566 then the default value is obtained by calling ``klass(*args, **kw)``.
567 If klass is a str, it is first resolved to an actual class and then
568 instantiated with ``klass(*args, **kw)``.
569 """
570
571 self._allow_none = allow_none
572 self.module = module or get_module_name()
573
574 if klass is None:
575 raise TraitletError('A %s traitlet must have a class specified.' %
576 self.__class__.__name__ )
577 elif not isinstance(klass, (basestring,) + ClassTypes ):
578 # klass is an instance so default value will be None
579 self.klass = klass.__class__
580 default_value = None
581 else:
582 # klass is a str or class so we handle args, kw
583 if args is None:
584 args = ()
585 if kw is None:
586 if isinstance(args, dict):
587 kw = args
588 args = ()
589 else:
590 kw = {}
591 if not isinstance(kw, dict):
592 raise TraitletError("The 'kw' argument must be a dict.")
593 if not isinstance(args, tuple):
594 raise TraitletError("The 'args' argument must be a tuple.")
595 self.klass = klass
596 # This tells my get_default_value that the default value
597 # instance needs to be generated when it is called. This
598 # is usually when TraitletType.__get__ is called for the 1st time.
599
600 default_value = DefaultValueGenerator(klass, *args, **kw)
601
602 super(Instance, self).__init__(default_value, **metadata)
603
604 def validate(self, obj, value):
605 if value is None:
606 if self._allow_none:
607 return value
608 self.validate_failed(obj, value)
609
610 # This is where self.klass is turned into a real class if it was
611 # a str initially. This happens the first time TraitletType.__set__
612 # is called. This does happen if a default value is generated by
613 # TraitletType.__get__.
614 if isinstance(self.klass, basestring):
615 self.resolve_class(obj, value)
616
617 if isinstance(value, self.klass):
618 return value
619 else:
620 self.validate_failed(obj, value)
621
622 def info ( self ):
623 klass = self.klass
624 if not isinstance( klass, basestring ):
625 klass = klass.__name__
626 result = class_of(klass)
627 if self._allow_none:
628 return result + ' or None'
629
630 return result
631
632 def get_default_value ( self ):
633 """Instantiate a default value instance.
634
635 When TraitletType.__get__ is called the first time, this is called
636 (if no value has been assigned) to get a default value instance.
637 """
638 dv = self.default_value
639 if isinstance(dv, DefaultValueGenerator):
640 klass = dv.klass
641 args = dv.args
642 kw = dv.kw
643 if isinstance(klass, basestring):
644 klass = self.validate_class(self.find_class(klass))
645 if klass is None:
646 raise TraitletError('Unable to locate class: ' + dv.klass)
647 return klass(*args, **kw)
648 else:
649 return dv
650
651
652 #-----------------------------------------------------------------------------
653 # Basic TraitletTypes implementations/subclasses
654 #-----------------------------------------------------------------------------
655
340
656
341 class Any(TraitletType):
657 class Any(TraitletType):
342 default_value = None
658 default_value = None
General Comments 0
You need to be logged in to leave comments. Login now