diff --git a/IPython/core/application.py b/IPython/core/application.py index cf26a31..918211b 100644 --- a/IPython/core/application.py +++ b/IPython/core/application.py @@ -25,7 +25,7 @@ from IPython.core import release, crashhandler from IPython.core.profiledir import ProfileDir, ProfileDirError from IPython.utils.path import get_ipython_dir, get_ipython_package_dir, ensure_dir_exists from IPython.utils import py3compat -from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict, Set, Instance +from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict, Set, Instance, Undefined if os.name == 'nt': programdata = os.environ.get('PROGRAMDATA', None) @@ -215,7 +215,7 @@ class BaseIPythonApplication(Application): return crashhandler.crash_handler_lite(etype, evalue, tb) def _ipython_dir_changed(self, name, old, new): - if old is not None: + if old is not None and old is not Undefined: str_old = py3compat.cast_bytes_py2(os.path.abspath(old), sys.getfilesystemencoding() ) diff --git a/IPython/html/widgets/widget_float.py b/IPython/html/widgets/widget_float.py index afc7936..a7cf157 100644 --- a/IPython/html/widgets/widget_float.py +++ b/IPython/html/widgets/widget_float.py @@ -1,4 +1,4 @@ -"""Float class. +"""Float class. Represents an unbounded float using a widget. """ @@ -15,7 +15,8 @@ Represents an unbounded float using a widget. #----------------------------------------------------------------------------- from .widget import DOMWidget, register from .trait_types import Color -from IPython.utils.traitlets import Unicode, CFloat, Bool, CaselessStrEnum, Tuple +from IPython.utils.traitlets import (Unicode, CFloat, Bool, CaselessStrEnum, + Tuple, TraitError) from IPython.utils.warn import DeprecatedClass #----------------------------------------------------------------------------- @@ -31,39 +32,37 @@ class _Float(DOMWidget): kwargs['value'] = value super(_Float, self).__init__(**kwargs) + class _BoundedFloat(_Float): max = CFloat(100.0, help="Max value", sync=True) min = CFloat(0.0, help="Min value", sync=True) - step = CFloat(0.1, help="Minimum step that the value can take (ignored by some views)", sync=True) + step = CFloat(0.1, help="Minimum step to increment the value (ignored by some views)", sync=True) def __init__(self, *pargs, **kwargs): """Constructor""" super(_BoundedFloat, self).__init__(*pargs, **kwargs) - self._handle_value_changed('value', None, self.value) - self._handle_max_changed('max', None, self.max) - self._handle_min_changed('min', None, self.min) - self.on_trait_change(self._handle_value_changed, 'value') - self.on_trait_change(self._handle_max_changed, 'max') - self.on_trait_change(self._handle_min_changed, 'min') - def _handle_value_changed(self, name, old, new): - """Validate value.""" - if self.min > new or new > self.max: - self.value = min(max(new, self.min), self.max) + def _value_validate(self, value, trait): + """Cap and floor value""" + if self.min > value or self.max < value: + value = min(max(value, self.min), self.max) + return value - def _handle_max_changed(self, name, old, new): - """Make sure the min is always <= the max.""" - if new < self.min: - raise ValueError("setting max < min") - if new < self.value: - self.value = new + def _min_validate(self, min, trait): + """Enforce min <= value <= max""" + if min > self.max: + raise TraitError("Setting min > max") + if min > self.value: + self.value = min + return min - def _handle_min_changed(self, name, old, new): - """Make sure the max is always >= the min.""" - if new > self.max: - raise ValueError("setting min > max") - if new > self.value: - self.value = new + def _max_validate(self, max, trait): + """Enforce min <= value <= max""" + if max < self.min: + raise TraitError("setting max < min") + if max < self.value: + self.value = max + return max @register('IPython.FloatText') @@ -76,9 +75,9 @@ class FloatText(_Float): value : float value displayed description : str - description displayed next to the textbox + description displayed next to the text box color : str Unicode color code (eg. '#C13535'), optional - color of the value displayed + color of the value displayed """ _view_name = Unicode('FloatTextView', sync=True) diff --git a/IPython/html/widgets/widget_int.py b/IPython/html/widgets/widget_int.py index e006054..e76f26b 100644 --- a/IPython/html/widgets/widget_int.py +++ b/IPython/html/widgets/widget_int.py @@ -1,4 +1,4 @@ -"""Int class. +"""Int class. Represents an unbounded int using a widget. """ @@ -15,7 +15,8 @@ Represents an unbounded int using a widget. #----------------------------------------------------------------------------- from .widget import DOMWidget, register from .trait_types import Color -from IPython.utils.traitlets import Unicode, CInt, Bool, CaselessStrEnum, Tuple +from IPython.utils.traitlets import (Unicode, CInt, Bool, CaselessStrEnum, + Tuple, TraitError) from IPython.utils.warn import DeprecatedClass #----------------------------------------------------------------------------- @@ -32,41 +33,39 @@ class _Int(DOMWidget): kwargs['value'] = value super(_Int, self).__init__(**kwargs) + class _BoundedInt(_Int): """Base class used to create widgets that represent a int that is bounded by a minium and maximum.""" - step = CInt(1, help="Minimum step that the value can take (ignored by some views)", sync=True) + step = CInt(1, help="Minimum step to increment the value (ignored by some views)", sync=True) max = CInt(100, help="Max value", sync=True) min = CInt(0, help="Min value", sync=True) def __init__(self, *pargs, **kwargs): """Constructor""" super(_BoundedInt, self).__init__(*pargs, **kwargs) - self._handle_value_changed('value', None, self.value) - self._handle_max_changed('max', None, self.max) - self._handle_min_changed('min', None, self.min) - self.on_trait_change(self._handle_value_changed, 'value') - self.on_trait_change(self._handle_max_changed, 'max') - self.on_trait_change(self._handle_min_changed, 'min') - - def _handle_value_changed(self, name, old, new): - """Validate value.""" - if self.min > new or new > self.max: - self.value = min(max(new, self.min), self.max) - - def _handle_max_changed(self, name, old, new): - """Make sure the min is always <= the max.""" - if new < self.min: - raise ValueError("setting max < min") - if new < self.value: - self.value = new - - def _handle_min_changed(self, name, old, new): - """Make sure the max is always >= the min.""" - if new > self.max: - raise ValueError("setting min > max") - if new > self.value: - self.value = new + + def _value_validate(self, value, trait): + """Cap and floor value""" + if self.min > value or self.max < value: + value = min(max(value, self.min), self.max) + return value + + def _min_validate(self, min, trait): + """Enforce min <= value <= max""" + if min > self.max: + raise TraitError("Setting min > max") + if min > self.value: + self.value = min + return min + + def _max_validate(self, max, trait): + """Enforce min <= value <= max""" + if max < self.min: + raise TraitError("setting max < min") + if max < self.value: + self.value = max + return max @register('IPython.IntText') class IntText(_Int): diff --git a/IPython/utils/tests/test_traitlets.py b/IPython/utils/tests/test_traitlets.py index 6004907..6333d2a 100644 --- a/IPython/utils/tests/test_traitlets.py +++ b/IPython/utils/tests/test_traitlets.py @@ -1333,11 +1333,20 @@ def test_pickle_hastraits(): def test_hold_trait_notifications(): changes = [] + class Test(HasTraits): a = Integer(0) + b = Integer(0) + def _a_changed(self, name, old, new): changes.append((old, new)) - + + def _b_validate(self, value, trait): + if value != 0: + raise TraitError('Only 0 is a valid value') + return value + + # Test context manager and nesting t = Test() with t.hold_trait_notifications(): with t.hold_trait_notifications(): @@ -1356,8 +1365,16 @@ def test_hold_trait_notifications(): t.a = 4 nt.assert_equal(t.a, 4) nt.assert_equal(changes, []) - nt.assert_equal(changes, [(0,1), (1,2), (2,3), (3,4)]) + nt.assert_equal(changes, [(3,4)]) + # Test roll-back + try: + with t.hold_trait_notifications(): + t.b = 1 # raises a Trait error + except: + pass + nt.assert_equal(t.b, 0) + class OrderTraits(HasTraits): notified = Dict() diff --git a/IPython/utils/traitlets.py b/IPython/utils/traitlets.py index 076091a..92fff44 100644 --- a/IPython/utils/traitlets.py +++ b/IPython/utils/traitlets.py @@ -319,7 +319,6 @@ class TraitType(object): accept superclasses for :class:`This` values. """ - metadata = {} default_value = Undefined allow_none = False @@ -447,7 +446,7 @@ class TraitType(object): try: old_value = obj._trait_values[self.name] except KeyError: - old_value = None + old_value = Undefined obj._trait_values[self.name] = new_value try: @@ -465,13 +464,14 @@ class TraitType(object): return value if hasattr(self, 'validate'): value = self.validate(obj, value) - try: - obj_validate = getattr(obj, '_%s_validate' % self.name) - except (AttributeError, RuntimeError): - # Qt mixins raise RuntimeError on missing attrs accessed before __init__ - pass - else: - value = obj_validate(value, self) + if obj._cross_validation_lock is False: + value = self._cross_validate(obj, value) + return value + + def _cross_validate(self, obj, value): + if hasattr(obj, '_%s_validate' % self.name): + cross_validate = getattr(obj, '_%s_validate' % self.name) + value = cross_validate(value, self) return value def __or__(self, other): @@ -542,6 +542,7 @@ class MetaHasTraits(type): v.this_class = cls super(MetaHasTraits, cls).__init__(name, bases, classdict) + class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)): def __new__(cls, *args, **kw): @@ -555,6 +556,7 @@ class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)): inst._trait_values = {} inst._trait_notifiers = {} inst._trait_dyn_inits = {} + inst._cross_validation_lock = True # Here we tell all the TraitType instances to set their default # values on the instance. for key in dir(cls): @@ -570,43 +572,74 @@ class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)): value.instance_init() if key not in kw: value.set_default_value(inst) - + inst._cross_validation_lock = False return inst def __init__(self, *args, **kw): # Allow trait values to be set using keyword arguments. # We need to use setattr for this to trigger validation and # notifications. - with self.hold_trait_notifications(): for key, value in iteritems(kw): setattr(self, key, value) - + @contextlib.contextmanager def hold_trait_notifications(self): - """Context manager for bundling trait change notifications - - Use this when doing multiple trait assignments (init, config), - to avoid race conditions in trait notifiers requesting other trait values. + """Context manager for bundling trait change notifications and cross + validation. + + Use this when doing multiple trait assignments (init, config), to avoid + race conditions in trait notifiers requesting other trait values. All trait notifications will fire after all values have been assigned. """ - _notify_trait = self._notify_trait - notifications = [] - self._notify_trait = lambda *a: notifications.append(a) - - try: + if self._cross_validation_lock is True: yield - finally: - self._notify_trait = _notify_trait - if isinstance(_notify_trait, types.MethodType): - # FIXME: remove when support is bumped to 3.4. - # when original method is restored, - # remove the redundant value from __dict__ - # (only used to preserve pickleability on Python < 3.4) - self.__dict__.pop('_notify_trait', None) - # trigger delayed notifications - for args in notifications: - self._notify_trait(*args) + return + else: + self._cross_validation_lock = True + cache = {} + notifications = {} + _notify_trait = self._notify_trait + + def cache_values(*a): + cache[a[0]] = a + + def hold_notifications(*a): + notifications[a[0]] = a + + self._notify_trait = cache_values + + try: + yield + finally: + try: + self._notify_trait = hold_notifications + for name in cache: + if hasattr(self, '_%s_validate' % name): + cross_validate = getattr(self, '_%s_validate' % name) + setattr(self, name, cross_validate(getattr(self, name), self)) + except TraitError as e: + self._notify_trait = lambda *x: None + for name in cache: + if cache[name][1] is not Undefined: + setattr(self, name, cache[name][1]) + else: + delattr(self, name) + cache = {} + notifications = {} + raise e + finally: + self._notify_trait = _notify_trait + self._cross_validation_lock = False + if isinstance(_notify_trait, types.MethodType): + # FIXME: remove when support is bumped to 3.4. + # when original method is restored, + # remove the redundant value from __dict__ + # (only used to preserve pickleability on Python < 3.4) + self.__dict__.pop('_notify_trait', None) + # trigger delayed notifications + for v in dict(cache, **notifications).values(): + self._notify_trait(*v) def _notify_trait(self, name, old_value, new_value):