##// END OF EJS Templates
Merge pull request #6128 from jasongrout/widget-trait-serialization...
Brian E. Granger -
r17330:3c3d1ae5 merge
parent child Browse files
Show More
@@ -1,482 +1,484
1 1 // Copyright (c) IPython Development Team.
2 2 // Distributed under the terms of the Modified BSD License.
3 3
4 4 define(["widgets/js/manager",
5 5 "underscore",
6 6 "backbone",
7 7 "jquery",
8 8 "base/js/namespace",
9 9 ], function(widgetmanager, _, Backbone, $, IPython){
10 10
11 11 var WidgetModel = Backbone.Model.extend({
12 12 constructor: function (widget_manager, model_id, comm) {
13 13 // Constructor
14 14 //
15 15 // Creates a WidgetModel instance.
16 16 //
17 17 // Parameters
18 18 // ----------
19 19 // widget_manager : WidgetManager instance
20 20 // model_id : string
21 21 // An ID unique to this model.
22 22 // comm : Comm instance (optional)
23 23 this.widget_manager = widget_manager;
24 24 this._buffered_state_diff = {};
25 25 this.pending_msgs = 0;
26 26 this.msg_buffer = null;
27 27 this.key_value_lock = null;
28 28 this.id = model_id;
29 29 this.views = [];
30 30
31 31 if (comm !== undefined) {
32 32 // Remember comm associated with the model.
33 33 this.comm = comm;
34 34 comm.model = this;
35 35
36 36 // Hook comm messages up to model.
37 37 comm.on_close($.proxy(this._handle_comm_closed, this));
38 38 comm.on_msg($.proxy(this._handle_comm_msg, this));
39 39 }
40 40 return Backbone.Model.apply(this);
41 41 },
42 42
43 43 send: function (content, callbacks) {
44 44 // Send a custom msg over the comm.
45 45 if (this.comm !== undefined) {
46 46 var data = {method: 'custom', content: content};
47 47 this.comm.send(data, callbacks);
48 48 this.pending_msgs++;
49 49 }
50 50 },
51 51
52 52 _handle_comm_closed: function (msg) {
53 53 // Handle when a widget is closed.
54 54 this.trigger('comm:close');
55 55 delete this.comm.model; // Delete ref so GC will collect widget model.
56 56 delete this.comm;
57 57 delete this.model_id; // Delete id from model so widget manager cleans up.
58 58 _.each(this.views, function(view, i) {
59 59 view.remove();
60 60 });
61 61 },
62 62
63 63 _handle_comm_msg: function (msg) {
64 64 // Handle incoming comm msg.
65 65 var method = msg.content.data.method;
66 66 switch (method) {
67 67 case 'update':
68 68 this.apply_update(msg.content.data.state);
69 69 break;
70 70 case 'custom':
71 71 this.trigger('msg:custom', msg.content.data.content);
72 72 break;
73 73 case 'display':
74 74 this.widget_manager.display_view(msg, this);
75 75 break;
76 76 }
77 77 },
78 78
79 79 apply_update: function (state) {
80 80 // Handle when a widget is updated via the python side.
81 81 var that = this;
82 82 _.each(state, function(value, key) {
83 83 that.key_value_lock = [key, value];
84 84 try {
85 85 WidgetModel.__super__.set.apply(that, [key, that._unpack_models(value)]);
86 86 } finally {
87 87 that.key_value_lock = null;
88 88 }
89 89 });
90 90 },
91 91
92 92 _handle_status: function (msg, callbacks) {
93 93 // Handle status msgs.
94 94
95 95 // execution_state : ('busy', 'idle', 'starting')
96 96 if (this.comm !== undefined) {
97 97 if (msg.content.execution_state ==='idle') {
98 98 // Send buffer if this message caused another message to be
99 99 // throttled.
100 100 if (this.msg_buffer !== null &&
101 101 (this.get('msg_throttle') || 3) === this.pending_msgs) {
102 102 var data = {method: 'backbone', sync_method: 'update', sync_data: this.msg_buffer};
103 103 this.comm.send(data, callbacks);
104 104 this.msg_buffer = null;
105 105 } else {
106 106 --this.pending_msgs;
107 107 }
108 108 }
109 109 }
110 110 },
111 111
112 112 callbacks: function(view) {
113 113 // Create msg callbacks for a comm msg.
114 114 var callbacks = this.widget_manager.callbacks(view);
115 115
116 116 if (callbacks.iopub === undefined) {
117 117 callbacks.iopub = {};
118 118 }
119 119
120 120 var that = this;
121 121 callbacks.iopub.status = function (msg) {
122 122 that._handle_status(msg, callbacks);
123 123 };
124 124 return callbacks;
125 125 },
126 126
127 127 set: function(key, val, options) {
128 128 // Set a value.
129 129 var return_value = WidgetModel.__super__.set.apply(this, arguments);
130 130
131 131 // Backbone only remembers the diff of the most recent set()
132 132 // operation. Calling set multiple times in a row results in a
133 133 // loss of diff information. Here we keep our own running diff.
134 134 this._buffered_state_diff = $.extend(this._buffered_state_diff, this.changedAttributes() || {});
135 135 return return_value;
136 136 },
137 137
138 138 sync: function (method, model, options) {
139 139 // Handle sync to the back-end. Called when a model.save() is called.
140 140
141 141 // Make sure a comm exists.
142 142 var error = options.error || function() {
143 143 console.error('Backbone sync error:', arguments);
144 144 };
145 145 if (this.comm === undefined) {
146 146 error();
147 147 return false;
148 148 }
149 149
150 150 // Delete any key value pairs that the back-end already knows about.
151 151 var attrs = (method === 'patch') ? options.attrs : model.toJSON(options);
152 152 if (this.key_value_lock !== null) {
153 153 var key = this.key_value_lock[0];
154 154 var value = this.key_value_lock[1];
155 155 if (attrs[key] === value) {
156 156 delete attrs[key];
157 157 }
158 158 }
159 159
160 160 // Only sync if there are attributes to send to the back-end.
161 161 attrs = this._pack_models(attrs);
162 162 if (_.size(attrs) > 0) {
163 163
164 164 // If this message was sent via backbone itself, it will not
165 165 // have any callbacks. It's important that we create callbacks
166 166 // so we can listen for status messages, etc...
167 167 var callbacks = options.callbacks || this.callbacks();
168 168
169 169 // Check throttle.
170 170 if (this.pending_msgs >= (this.get('msg_throttle') || 3)) {
171 171 // The throttle has been exceeded, buffer the current msg so
172 172 // it can be sent once the kernel has finished processing
173 173 // some of the existing messages.
174 174
175 175 // Combine updates if it is a 'patch' sync, otherwise replace updates
176 176 switch (method) {
177 177 case 'patch':
178 178 this.msg_buffer = $.extend(this.msg_buffer || {}, attrs);
179 179 break;
180 180 case 'update':
181 181 case 'create':
182 182 this.msg_buffer = attrs;
183 183 break;
184 184 default:
185 185 error();
186 186 return false;
187 187 }
188 188 this.msg_buffer_callbacks = callbacks;
189 189
190 190 } else {
191 191 // We haven't exceeded the throttle, send the message like
192 192 // normal.
193 193 var data = {method: 'backbone', sync_data: attrs};
194 194 this.comm.send(data, callbacks);
195 195 this.pending_msgs++;
196 196 }
197 197 }
198 198 // Since the comm is a one-way communication, assume the message
199 199 // arrived. Don't call success since we don't have a model back from the server
200 200 // this means we miss out on the 'sync' event.
201 201 this._buffered_state_diff = {};
202 202 },
203 203
204 204 save_changes: function(callbacks) {
205 205 // Push this model's state to the back-end
206 206 //
207 207 // This invokes a Backbone.Sync.
208 208 this.save(this._buffered_state_diff, {patch: true, callbacks: callbacks});
209 209 },
210 210
211 211 _pack_models: function(value) {
212 212 // Replace models with model ids recursively.
213 213 var that = this;
214 214 var packed;
215 215 if (value instanceof Backbone.Model) {
216 return value.id;
216 return "IPY_MODEL_" + value.id;
217 217
218 218 } else if ($.isArray(value)) {
219 219 packed = [];
220 220 _.each(value, function(sub_value, key) {
221 221 packed.push(that._pack_models(sub_value));
222 222 });
223 223 return packed;
224 224
225 225 } else if (value instanceof Object) {
226 226 packed = {};
227 227 _.each(value, function(sub_value, key) {
228 228 packed[key] = that._pack_models(sub_value);
229 229 });
230 230 return packed;
231 231
232 232 } else {
233 233 return value;
234 234 }
235 235 },
236 236
237 237 _unpack_models: function(value) {
238 238 // Replace model ids with models recursively.
239 239 var that = this;
240 240 var unpacked;
241 241 if ($.isArray(value)) {
242 242 unpacked = [];
243 243 _.each(value, function(sub_value, key) {
244 244 unpacked.push(that._unpack_models(sub_value));
245 245 });
246 246 return unpacked;
247 247
248 248 } else if (value instanceof Object) {
249 249 unpacked = {};
250 250 _.each(value, function(sub_value, key) {
251 251 unpacked[key] = that._unpack_models(sub_value);
252 252 });
253 253 return unpacked;
254 254
255 } else if (typeof value === 'string' && value.slice(0,10) === "IPY_MODEL_") {
256 var model = this.widget_manager.get_model(value.slice(10, value.length));
257 if (model) {
258 return model;
259 } else {
260 return value;
261 }
255 262 } else {
256 var model = this.widget_manager.get_model(value);
257 if (model) {
258 return model;
259 } else {
260 263 return value;
261 }
262 264 }
263 265 },
264 266
265 267 });
266 268 widgetmanager.WidgetManager.register_widget_model('WidgetModel', WidgetModel);
267 269
268 270
269 271 var WidgetView = Backbone.View.extend({
270 272 initialize: function(parameters) {
271 273 // Public constructor.
272 274 this.model.on('change',this.update,this);
273 275 this.options = parameters.options;
274 276 this.child_model_views = {};
275 277 this.child_views = {};
276 278 this.model.views.push(this);
277 279 this.id = this.id || IPython.utils.uuid();
278 280 },
279 281
280 282 update: function(){
281 283 // Triggered on model change.
282 284 //
283 285 // Update view to be consistent with this.model
284 286 },
285 287
286 288 create_child_view: function(child_model, options) {
287 289 // Create and return a child view.
288 290 //
289 291 // -given a model and (optionally) a view name if the view name is
290 292 // not given, it defaults to the model's default view attribute.
291 293
292 294 // TODO: this is hacky, and makes the view depend on this cell attribute and widget manager behavior
293 295 // it would be great to have the widget manager add the cell metadata
294 296 // to the subview without having to add it here.
295 297 options = $.extend({ parent: this }, options || {});
296 298 var child_view = this.model.widget_manager.create_view(child_model, options, this);
297 299
298 300 // Associate the view id with the model id.
299 301 if (this.child_model_views[child_model.id] === undefined) {
300 302 this.child_model_views[child_model.id] = [];
301 303 }
302 304 this.child_model_views[child_model.id].push(child_view.id);
303 305
304 306 // Remember the view by id.
305 307 this.child_views[child_view.id] = child_view;
306 308 return child_view;
307 309 },
308 310
309 311 pop_child_view: function(child_model) {
310 312 // Delete a child view that was previously created using create_child_view.
311 313 var view_ids = this.child_model_views[child_model.id];
312 314 if (view_ids !== undefined) {
313 315
314 316 // Only delete the first view in the list.
315 317 var view_id = view_ids[0];
316 318 var view = this.child_views[view_id];
317 319 delete this.child_views[view_id];
318 320 view_ids.splice(0,1);
319 321 child_model.views.pop(view);
320 322
321 323 // Remove the view list specific to this model if it is empty.
322 324 if (view_ids.length === 0) {
323 325 delete this.child_model_views[child_model.id];
324 326 }
325 327 return view;
326 328 }
327 329 return null;
328 330 },
329 331
330 332 do_diff: function(old_list, new_list, removed_callback, added_callback) {
331 333 // Difference a changed list and call remove and add callbacks for
332 334 // each removed and added item in the new list.
333 335 //
334 336 // Parameters
335 337 // ----------
336 338 // old_list : array
337 339 // new_list : array
338 340 // removed_callback : Callback(item)
339 341 // Callback that is called for each item removed.
340 342 // added_callback : Callback(item)
341 343 // Callback that is called for each item added.
342 344
343 345 // Walk the lists until an unequal entry is found.
344 346 var i;
345 347 for (i = 0; i < new_list.length; i++) {
346 348 if (i < old_list.length || new_list[i] !== old_list[i]) {
347 349 break;
348 350 }
349 351 }
350 352
351 353 // Remove the non-matching items from the old list.
352 354 for (var j = i; j < old_list.length; j++) {
353 355 removed_callback(old_list[j]);
354 356 }
355 357
356 358 // Add the rest of the new list items.
357 359 for (i; i < new_list.length; i++) {
358 360 added_callback(new_list[i]);
359 361 }
360 362 },
361 363
362 364 callbacks: function(){
363 365 // Create msg callbacks for a comm msg.
364 366 return this.model.callbacks(this);
365 367 },
366 368
367 369 render: function(){
368 370 // Render the view.
369 371 //
370 372 // By default, this is only called the first time the view is created
371 373 },
372 374
373 375 show: function(){
374 376 // Show the widget-area
375 377 if (this.options && this.options.cell &&
376 378 this.options.cell.widget_area !== undefined) {
377 379 this.options.cell.widget_area.show();
378 380 }
379 381 },
380 382
381 383 send: function (content) {
382 384 // Send a custom msg associated with this view.
383 385 this.model.send(content, this.callbacks());
384 386 },
385 387
386 388 touch: function () {
387 389 this.model.save_changes(this.callbacks());
388 390 },
389 391 });
390 392
391 393
392 394 var DOMWidgetView = WidgetView.extend({
393 395 initialize: function (options) {
394 396 // Public constructor
395 397
396 398 // In the future we may want to make changes more granular
397 399 // (e.g., trigger on visible:change).
398 400 this.model.on('change', this.update, this);
399 401 this.model.on('msg:custom', this.on_msg, this);
400 402 DOMWidgetView.__super__.initialize.apply(this, arguments);
401 403 this.on('displayed', this.show, this);
402 404 },
403 405
404 406 on_msg: function(msg) {
405 407 // Handle DOM specific msgs.
406 408 switch(msg.msg_type) {
407 409 case 'add_class':
408 410 this.add_class(msg.selector, msg.class_list);
409 411 break;
410 412 case 'remove_class':
411 413 this.remove_class(msg.selector, msg.class_list);
412 414 break;
413 415 }
414 416 },
415 417
416 418 add_class: function (selector, class_list) {
417 419 // Add a DOM class to an element.
418 420 this._get_selector_element(selector).addClass(class_list);
419 421 },
420 422
421 423 remove_class: function (selector, class_list) {
422 424 // Remove a DOM class from an element.
423 425 this._get_selector_element(selector).removeClass(class_list);
424 426 },
425 427
426 428 update: function () {
427 429 // Update the contents of this view
428 430 //
429 431 // Called when the model is changed. The model may have been
430 432 // changed by another view or by a state update from the back-end.
431 433 // The very first update seems to happen before the element is
432 434 // finished rendering so we use setTimeout to give the element time
433 435 // to render
434 436 var e = this.$el;
435 437 var visible = this.model.get('visible');
436 438 setTimeout(function() {e.toggle(visible);},0);
437 439
438 440 var css = this.model.get('_css');
439 441 if (css === undefined) {return;}
440 442 for (var i = 0; i < css.length; i++) {
441 443 // Apply the css traits to all elements that match the selector.
442 444 var selector = css[i][0];
443 445 var elements = this._get_selector_element(selector);
444 446 if (elements.length > 0) {
445 447 var trait_key = css[i][1];
446 448 var trait_value = css[i][2];
447 449 elements.css(trait_key ,trait_value);
448 450 }
449 451 }
450 452 },
451 453
452 454 _get_selector_element: function (selector) {
453 455 // Get the elements via the css selector.
454 456
455 457 // If the selector is blank, apply the style to the $el_to_style
456 458 // element. If the $el_to_style element is not defined, use apply
457 459 // the style to the view's element.
458 460 var elements;
459 461 if (!selector) {
460 462 if (this.$el_to_style === undefined) {
461 463 elements = this.$el;
462 464 } else {
463 465 elements = this.$el_to_style;
464 466 }
465 467 } else {
466 468 elements = this.$el.find(selector);
467 469 }
468 470 return elements;
469 471 },
470 472 });
471 473
472 474 var widget = {
473 475 'WidgetModel': WidgetModel,
474 476 'WidgetView': WidgetView,
475 477 'DOMWidgetView': DOMWidgetView,
476 478 };
477 479
478 480 // For backwards compatability.
479 481 $.extend(IPython, widget);
480 482
481 483 return widget;
482 484 });
@@ -1,440 +1,453
1 1 """Base Widget class. Allows user to create widgets in the back-end that render
2 2 in the IPython notebook front-end.
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (c) 2013, the IPython Development Team.
6 6 #
7 7 # Distributed under the terms of the Modified BSD License.
8 8 #
9 9 # The full license is in the file COPYING.txt, distributed with this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15 from contextlib import contextmanager
16 16
17 17 from IPython.core.getipython import get_ipython
18 18 from IPython.kernel.comm import Comm
19 19 from IPython.config import LoggingConfigurable
20 20 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, Tuple, Int
21 21 from IPython.utils.py3compat import string_types
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Classes
25 25 #-----------------------------------------------------------------------------
26 26 class CallbackDispatcher(LoggingConfigurable):
27 27 """A structure for registering and running callbacks"""
28 28 callbacks = List()
29 29
30 30 def __call__(self, *args, **kwargs):
31 31 """Call all of the registered callbacks."""
32 32 value = None
33 33 for callback in self.callbacks:
34 34 try:
35 35 local_value = callback(*args, **kwargs)
36 36 except Exception as e:
37 37 ip = get_ipython()
38 38 if ip is None:
39 39 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
40 40 else:
41 41 ip.showtraceback()
42 42 else:
43 43 value = local_value if local_value is not None else value
44 44 return value
45 45
46 46 def register_callback(self, callback, remove=False):
47 47 """(Un)Register a callback
48 48
49 49 Parameters
50 50 ----------
51 51 callback: method handle
52 52 Method to be registered or unregistered.
53 53 remove=False: bool
54 54 Whether to unregister the callback."""
55 55
56 56 # (Un)Register the callback.
57 57 if remove and callback in self.callbacks:
58 58 self.callbacks.remove(callback)
59 59 elif not remove and callback not in self.callbacks:
60 60 self.callbacks.append(callback)
61 61
62 62 def _show_traceback(method):
63 63 """decorator for showing tracebacks in IPython"""
64 64 def m(self, *args, **kwargs):
65 65 try:
66 66 return(method(self, *args, **kwargs))
67 67 except Exception as e:
68 68 ip = get_ipython()
69 69 if ip is None:
70 70 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
71 71 else:
72 72 ip.showtraceback()
73 73 return m
74 74
75 75 class Widget(LoggingConfigurable):
76 76 #-------------------------------------------------------------------------
77 77 # Class attributes
78 78 #-------------------------------------------------------------------------
79 79 _widget_construction_callback = None
80 80 widgets = {}
81 81
82 82 @staticmethod
83 83 def on_widget_constructed(callback):
84 84 """Registers a callback to be called when a widget is constructed.
85 85
86 86 The callback must have the following signature:
87 87 callback(widget)"""
88 88 Widget._widget_construction_callback = callback
89 89
90 90 @staticmethod
91 91 def _call_widget_constructed(widget):
92 92 """Static method, called when a widget is constructed."""
93 93 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
94 94 Widget._widget_construction_callback(widget)
95 95
96 96 #-------------------------------------------------------------------------
97 97 # Traits
98 98 #-------------------------------------------------------------------------
99 99 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
100 100 registered in the front-end to create and sync this widget with.""")
101 101 _view_name = Unicode(help="""Default view registered in the front-end
102 102 to use to represent the widget.""", sync=True)
103 103 _comm = Instance('IPython.kernel.comm.Comm')
104 104
105 105 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
106 106 front-end can send before receiving an idle msg from the back-end.""")
107 107
108 108 keys = List()
109 109 def _keys_default(self):
110 110 return [name for name in self.traits(sync=True)]
111 111
112 112 _property_lock = Tuple((None, None))
113 113
114 114 _display_callbacks = Instance(CallbackDispatcher, ())
115 115 _msg_callbacks = Instance(CallbackDispatcher, ())
116 116
117 117 #-------------------------------------------------------------------------
118 118 # (Con/de)structor
119 119 #-------------------------------------------------------------------------
120 120 def __init__(self, **kwargs):
121 121 """Public constructor"""
122 122 super(Widget, self).__init__(**kwargs)
123 123
124 124 self.on_trait_change(self._handle_property_changed, self.keys)
125 125 Widget._call_widget_constructed(self)
126 126
127 127 def __del__(self):
128 128 """Object disposal"""
129 129 self.close()
130 130
131 131 #-------------------------------------------------------------------------
132 132 # Properties
133 133 #-------------------------------------------------------------------------
134 134
135 135 @property
136 136 def comm(self):
137 137 """Gets the Comm associated with this widget.
138 138
139 139 If a Comm doesn't exist yet, a Comm will be created automagically."""
140 140 if self._comm is None:
141 141 # Create a comm.
142 142 self._comm = Comm(target_name=self._model_name)
143 143 self._comm.on_msg(self._handle_msg)
144 144 self._comm.on_close(self._close)
145 145 Widget.widgets[self.model_id] = self
146 146
147 147 # first update
148 148 self.send_state()
149 149 return self._comm
150 150
151 151 @property
152 152 def model_id(self):
153 153 """Gets the model id of this widget.
154 154
155 155 If a Comm doesn't exist yet, a Comm will be created automagically."""
156 156 return self.comm.comm_id
157 157
158 158 #-------------------------------------------------------------------------
159 159 # Methods
160 160 #-------------------------------------------------------------------------
161 161 def _close(self):
162 162 """Private close - cleanup objects, registry entries"""
163 163 del Widget.widgets[self.model_id]
164 164 self._comm = None
165 165
166 166 def close(self):
167 167 """Close method.
168 168
169 169 Closes the widget which closes the underlying comm.
170 170 When the comm is closed, all of the widget views are automatically
171 171 removed from the front-end."""
172 172 if self._comm is not None:
173 173 self._comm.close()
174 174 self._close()
175 175
176 176 def send_state(self, key=None):
177 177 """Sends the widget state, or a piece of it, to the front-end.
178 178
179 179 Parameters
180 180 ----------
181 181 key : unicode (optional)
182 182 A single property's name to sync with the front-end.
183 183 """
184 184 self._send({
185 185 "method" : "update",
186 186 "state" : self.get_state()
187 187 })
188 188
189 189 def get_state(self, key=None):
190 190 """Gets the widget state, or a piece of it.
191 191
192 192 Parameters
193 193 ----------
194 194 key : unicode (optional)
195 195 A single property's name to get.
196 196 """
197 197 keys = self.keys if key is None else [key]
198 return {k: self._pack_widgets(getattr(self, k)) for k in keys}
199
198 state = {}
199 for k in keys:
200 f = self.trait_metadata(k, 'to_json')
201 if f is None:
202 f = self._trait_to_json
203 value = getattr(self, k)
204 state[k] = f(value)
205 return state
206
200 207 def send(self, content):
201 208 """Sends a custom msg to the widget model in the front-end.
202 209
203 210 Parameters
204 211 ----------
205 212 content : dict
206 213 Content of the message to send.
207 214 """
208 215 self._send({"method": "custom", "content": content})
209 216
210 217 def on_msg(self, callback, remove=False):
211 218 """(Un)Register a custom msg receive callback.
212 219
213 220 Parameters
214 221 ----------
215 222 callback: callable
216 223 callback will be passed two arguments when a message arrives::
217 224
218 225 callback(widget, content)
219 226
220 227 remove: bool
221 228 True if the callback should be unregistered."""
222 229 self._msg_callbacks.register_callback(callback, remove=remove)
223 230
224 231 def on_displayed(self, callback, remove=False):
225 232 """(Un)Register a widget displayed callback.
226 233
227 234 Parameters
228 235 ----------
229 236 callback: method handler
230 237 Must have a signature of::
231 238
232 239 callback(widget, **kwargs)
233 240
234 241 kwargs from display are passed through without modification.
235 242 remove: bool
236 243 True if the callback should be unregistered."""
237 244 self._display_callbacks.register_callback(callback, remove=remove)
238 245
239 246 #-------------------------------------------------------------------------
240 247 # Support methods
241 248 #-------------------------------------------------------------------------
242 249 @contextmanager
243 250 def _lock_property(self, key, value):
244 251 """Lock a property-value pair.
245 252
246 253 NOTE: This, in addition to the single lock for all state changes, is
247 254 flawed. In the future we may want to look into buffering state changes
248 255 back to the front-end."""
249 256 self._property_lock = (key, value)
250 257 try:
251 258 yield
252 259 finally:
253 260 self._property_lock = (None, None)
254 261
255 262 def _should_send_property(self, key, value):
256 263 """Check the property lock (property_lock)"""
257 264 return key != self._property_lock[0] or \
258 265 value != self._property_lock[1]
259 266
260 267 # Event handlers
261 268 @_show_traceback
262 269 def _handle_msg(self, msg):
263 270 """Called when a msg is received from the front-end"""
264 271 data = msg['content']['data']
265 272 method = data['method']
266 273 if not method in ['backbone', 'custom']:
267 274 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
268 275
269 276 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
270 277 if method == 'backbone' and 'sync_data' in data:
271 278 sync_data = data['sync_data']
272 279 self._handle_receive_state(sync_data) # handles all methods
273 280
274 281 # Handle a custom msg from the front-end
275 282 elif method == 'custom':
276 283 if 'content' in data:
277 284 self._handle_custom_msg(data['content'])
278 285
279 286 def _handle_receive_state(self, sync_data):
280 287 """Called when a state is received from the front-end."""
281 288 for name in self.keys:
282 289 if name in sync_data:
283 value = self._unpack_widgets(sync_data[name])
290 f = self.trait_metadata(name, 'from_json')
291 if f is None:
292 f = self._trait_from_json
293 value = f(sync_data[name])
284 294 with self._lock_property(name, value):
285 295 setattr(self, name, value)
286 296
287 297 def _handle_custom_msg(self, content):
288 298 """Called when a custom msg is received."""
289 299 self._msg_callbacks(self, content)
290 300
291 301 def _handle_property_changed(self, name, old, new):
292 302 """Called when a property has been changed."""
293 303 # Make sure this isn't information that the front-end just sent us.
294 304 if self._should_send_property(name, new):
295 305 # Send new state to front-end
296 306 self.send_state(key=name)
297 307
298 308 def _handle_displayed(self, **kwargs):
299 309 """Called when a view has been displayed for this widget instance"""
300 310 self._display_callbacks(self, **kwargs)
301 311
302 def _pack_widgets(self, x):
303 """Recursively converts all widget instances to model id strings.
312 def _trait_to_json(self, x):
313 """Convert a trait value to json
304 314
305 Children widgets will be stored and transmitted to the front-end by
306 their model ids. Return value must be JSON-able."""
315 Traverse lists/tuples and dicts and serialize their values as well.
316 Replace any widgets with their model_id
317 """
307 318 if isinstance(x, dict):
308 return {k: self._pack_widgets(v) for k, v in x.items()}
319 return {k: self._trait_to_json(v) for k, v in x.items()}
309 320 elif isinstance(x, (list, tuple)):
310 return [self._pack_widgets(v) for v in x]
321 return [self._trait_to_json(v) for v in x]
311 322 elif isinstance(x, Widget):
312 return x.model_id
323 return "IPY_MODEL_" + x.model_id
313 324 else:
314 325 return x # Value must be JSON-able
315 326
316 def _unpack_widgets(self, x):
317 """Recursively converts all model id strings to widget instances.
327 def _trait_from_json(self, x):
328 """Convert json values to objects
318 329
319 Children widgets will be stored and transmitted to the front-end by
320 their model ids."""
330 Replace any strings representing valid model id values to Widget references.
331 """
321 332 if isinstance(x, dict):
322 return {k: self._unpack_widgets(v) for k, v in x.items()}
333 return {k: self._trait_from_json(v) for k, v in x.items()}
323 334 elif isinstance(x, (list, tuple)):
324 return [self._unpack_widgets(v) for v in x]
325 elif isinstance(x, string_types):
326 return x if x not in Widget.widgets else Widget.widgets[x]
335 return [self._trait_from_json(v) for v in x]
336 elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
337 # we want to support having child widgets at any level in a hierarchy
338 # trusting that a widget UUID will not appear out in the wild
339 return Widget.widgets[x]
327 340 else:
328 341 return x
329 342
330 343 def _ipython_display_(self, **kwargs):
331 344 """Called when `IPython.display.display` is called on the widget."""
332 345 # Show view. By sending a display message, the comm is opened and the
333 346 # initial state is sent.
334 347 self._send({"method": "display"})
335 348 self._handle_displayed(**kwargs)
336 349
337 350 def _send(self, msg):
338 351 """Sends a message to the model in the front-end."""
339 352 self.comm.send(msg)
340 353
341 354
342 355 class DOMWidget(Widget):
343 356 visible = Bool(True, help="Whether the widget is visible.", sync=True)
344 357 _css = List(sync=True) # Internal CSS property list: (selector, key, value)
345 358
346 359 def get_css(self, key, selector=""):
347 360 """Get a CSS property of the widget.
348 361
349 362 Note: This function does not actually request the CSS from the
350 363 front-end; Only properties that have been set with set_css can be read.
351 364
352 365 Parameters
353 366 ----------
354 367 key: unicode
355 368 CSS key
356 369 selector: unicode (optional)
357 370 JQuery selector used when the CSS key/value was set.
358 371 """
359 372 if selector in self._css and key in self._css[selector]:
360 373 return self._css[selector][key]
361 374 else:
362 375 return None
363 376
364 377 def set_css(self, dict_or_key, value=None, selector=''):
365 378 """Set one or more CSS properties of the widget.
366 379
367 380 This function has two signatures:
368 381 - set_css(css_dict, selector='')
369 382 - set_css(key, value, selector='')
370 383
371 384 Parameters
372 385 ----------
373 386 css_dict : dict
374 387 CSS key/value pairs to apply
375 388 key: unicode
376 389 CSS key
377 390 value:
378 391 CSS value
379 392 selector: unicode (optional, kwarg only)
380 393 JQuery selector to use to apply the CSS key/value. If no selector
381 394 is provided, an empty selector is used. An empty selector makes the
382 395 front-end try to apply the css to a default element. The default
383 396 element is an attribute unique to each view, which is a DOM element
384 397 of the view that should be styled with common CSS (see
385 398 `$el_to_style` in the Javascript code).
386 399 """
387 400 if value is None:
388 401 css_dict = dict_or_key
389 402 else:
390 403 css_dict = {dict_or_key: value}
391 404
392 405 for (key, value) in css_dict.items():
393 406 # First remove the selector/key pair from the css list if it exists.
394 407 # Then add the selector/key pair and new value to the bottom of the
395 408 # list.
396 409 self._css = [x for x in self._css if not (x[0]==selector and x[1]==key)]
397 410 self._css += [(selector, key, value)]
398 411 self.send_state('_css')
399 412
400 413 def add_class(self, class_names, selector=""):
401 414 """Add class[es] to a DOM element.
402 415
403 416 Parameters
404 417 ----------
405 418 class_names: unicode or list
406 419 Class name(s) to add to the DOM element(s).
407 420 selector: unicode (optional)
408 421 JQuery selector to select the DOM element(s) that the class(es) will
409 422 be added to.
410 423 """
411 424 class_list = class_names
412 425 if isinstance(class_list, (list, tuple)):
413 426 class_list = ' '.join(class_list)
414 427
415 428 self.send({
416 429 "msg_type" : "add_class",
417 430 "class_list" : class_list,
418 431 "selector" : selector
419 432 })
420 433
421 434 def remove_class(self, class_names, selector=""):
422 435 """Remove class[es] from a DOM element.
423 436
424 437 Parameters
425 438 ----------
426 439 class_names: unicode or list
427 440 Class name(s) to remove from the DOM element(s).
428 441 selector: unicode (optional)
429 442 JQuery selector to select the DOM element(s) that the class(es) will
430 443 be removed from.
431 444 """
432 445 class_list = class_names
433 446 if isinstance(class_list, (list, tuple)):
434 447 class_list = ' '.join(class_list)
435 448
436 449 self.send({
437 450 "msg_type" : "remove_class",
438 451 "class_list" : class_list,
439 452 "selector" : selector,
440 453 })
@@ -1,1153 +1,1174
1 1 # encoding: utf-8
2 2 """Tests for IPython.utils.traitlets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6 #
7 7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
8 8 # also under the terms of the Modified BSD License.
9 9
10 10 import pickle
11 11 import re
12 12 import sys
13 13 from unittest import TestCase
14 14
15 15 import nose.tools as nt
16 16 from nose import SkipTest
17 17
18 18 from IPython.utils.traitlets import (
19 19 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
20 20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
21 21 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
22 22 ObjectName, DottedObjectName, CRegExp, link
23 23 )
24 24 from IPython.utils import py3compat
25 25 from IPython.testing.decorators import skipif
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Helper classes for testing
29 29 #-----------------------------------------------------------------------------
30 30
31 31
32 32 class HasTraitsStub(HasTraits):
33 33
34 34 def _notify_trait(self, name, old, new):
35 35 self._notify_name = name
36 36 self._notify_old = old
37 37 self._notify_new = new
38 38
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Test classes
42 42 #-----------------------------------------------------------------------------
43 43
44 44
45 45 class TestTraitType(TestCase):
46 46
47 47 def test_get_undefined(self):
48 48 class A(HasTraits):
49 49 a = TraitType
50 50 a = A()
51 51 self.assertEqual(a.a, Undefined)
52 52
53 53 def test_set(self):
54 54 class A(HasTraitsStub):
55 55 a = TraitType
56 56
57 57 a = A()
58 58 a.a = 10
59 59 self.assertEqual(a.a, 10)
60 60 self.assertEqual(a._notify_name, 'a')
61 61 self.assertEqual(a._notify_old, Undefined)
62 62 self.assertEqual(a._notify_new, 10)
63 63
64 64 def test_validate(self):
65 65 class MyTT(TraitType):
66 66 def validate(self, inst, value):
67 67 return -1
68 68 class A(HasTraitsStub):
69 69 tt = MyTT
70 70
71 71 a = A()
72 72 a.tt = 10
73 73 self.assertEqual(a.tt, -1)
74 74
75 75 def test_default_validate(self):
76 76 class MyIntTT(TraitType):
77 77 def validate(self, obj, value):
78 78 if isinstance(value, int):
79 79 return value
80 80 self.error(obj, value)
81 81 class A(HasTraits):
82 82 tt = MyIntTT(10)
83 83 a = A()
84 84 self.assertEqual(a.tt, 10)
85 85
86 86 # Defaults are validated when the HasTraits is instantiated
87 87 class B(HasTraits):
88 88 tt = MyIntTT('bad default')
89 89 self.assertRaises(TraitError, B)
90 90
91 91 def test_is_valid_for(self):
92 92 class MyTT(TraitType):
93 93 def is_valid_for(self, value):
94 94 return True
95 95 class A(HasTraits):
96 96 tt = MyTT
97 97
98 98 a = A()
99 99 a.tt = 10
100 100 self.assertEqual(a.tt, 10)
101 101
102 102 def test_value_for(self):
103 103 class MyTT(TraitType):
104 104 def value_for(self, value):
105 105 return 20
106 106 class A(HasTraits):
107 107 tt = MyTT
108 108
109 109 a = A()
110 110 a.tt = 10
111 111 self.assertEqual(a.tt, 20)
112 112
113 113 def test_info(self):
114 114 class A(HasTraits):
115 115 tt = TraitType
116 116 a = A()
117 117 self.assertEqual(A.tt.info(), 'any value')
118 118
119 119 def test_error(self):
120 120 class A(HasTraits):
121 121 tt = TraitType
122 122 a = A()
123 123 self.assertRaises(TraitError, A.tt.error, a, 10)
124 124
125 125 def test_dynamic_initializer(self):
126 126 class A(HasTraits):
127 127 x = Int(10)
128 128 def _x_default(self):
129 129 return 11
130 130 class B(A):
131 131 x = Int(20)
132 132 class C(A):
133 133 def _x_default(self):
134 134 return 21
135 135
136 136 a = A()
137 137 self.assertEqual(a._trait_values, {})
138 138 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
139 139 self.assertEqual(a.x, 11)
140 140 self.assertEqual(a._trait_values, {'x': 11})
141 141 b = B()
142 142 self.assertEqual(b._trait_values, {'x': 20})
143 143 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
144 144 self.assertEqual(b.x, 20)
145 145 c = C()
146 146 self.assertEqual(c._trait_values, {})
147 147 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
148 148 self.assertEqual(c.x, 21)
149 149 self.assertEqual(c._trait_values, {'x': 21})
150 150 # Ensure that the base class remains unmolested when the _default
151 151 # initializer gets overridden in a subclass.
152 152 a = A()
153 153 c = C()
154 154 self.assertEqual(a._trait_values, {})
155 155 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
156 156 self.assertEqual(a.x, 11)
157 157 self.assertEqual(a._trait_values, {'x': 11})
158 158
159 159
160 160
161 161 class TestHasTraitsMeta(TestCase):
162 162
163 163 def test_metaclass(self):
164 164 self.assertEqual(type(HasTraits), MetaHasTraits)
165 165
166 166 class A(HasTraits):
167 167 a = Int
168 168
169 169 a = A()
170 170 self.assertEqual(type(a.__class__), MetaHasTraits)
171 171 self.assertEqual(a.a,0)
172 172 a.a = 10
173 173 self.assertEqual(a.a,10)
174 174
175 175 class B(HasTraits):
176 176 b = Int()
177 177
178 178 b = B()
179 179 self.assertEqual(b.b,0)
180 180 b.b = 10
181 181 self.assertEqual(b.b,10)
182 182
183 183 class C(HasTraits):
184 184 c = Int(30)
185 185
186 186 c = C()
187 187 self.assertEqual(c.c,30)
188 188 c.c = 10
189 189 self.assertEqual(c.c,10)
190 190
191 191 def test_this_class(self):
192 192 class A(HasTraits):
193 193 t = This()
194 194 tt = This()
195 195 class B(A):
196 196 tt = This()
197 197 ttt = This()
198 198 self.assertEqual(A.t.this_class, A)
199 199 self.assertEqual(B.t.this_class, A)
200 200 self.assertEqual(B.tt.this_class, B)
201 201 self.assertEqual(B.ttt.this_class, B)
202 202
203 203 class TestHasTraitsNotify(TestCase):
204 204
205 205 def setUp(self):
206 206 self._notify1 = []
207 207 self._notify2 = []
208 208
209 209 def notify1(self, name, old, new):
210 210 self._notify1.append((name, old, new))
211 211
212 212 def notify2(self, name, old, new):
213 213 self._notify2.append((name, old, new))
214 214
215 215 def test_notify_all(self):
216 216
217 217 class A(HasTraits):
218 218 a = Int
219 219 b = Float
220 220
221 221 a = A()
222 222 a.on_trait_change(self.notify1)
223 223 a.a = 0
224 224 self.assertEqual(len(self._notify1),0)
225 225 a.b = 0.0
226 226 self.assertEqual(len(self._notify1),0)
227 227 a.a = 10
228 228 self.assertTrue(('a',0,10) in self._notify1)
229 229 a.b = 10.0
230 230 self.assertTrue(('b',0.0,10.0) in self._notify1)
231 231 self.assertRaises(TraitError,setattr,a,'a','bad string')
232 232 self.assertRaises(TraitError,setattr,a,'b','bad string')
233 233 self._notify1 = []
234 234 a.on_trait_change(self.notify1,remove=True)
235 235 a.a = 20
236 236 a.b = 20.0
237 237 self.assertEqual(len(self._notify1),0)
238 238
239 239 def test_notify_one(self):
240 240
241 241 class A(HasTraits):
242 242 a = Int
243 243 b = Float
244 244
245 245 a = A()
246 246 a.on_trait_change(self.notify1, 'a')
247 247 a.a = 0
248 248 self.assertEqual(len(self._notify1),0)
249 249 a.a = 10
250 250 self.assertTrue(('a',0,10) in self._notify1)
251 251 self.assertRaises(TraitError,setattr,a,'a','bad string')
252 252
253 253 def test_subclass(self):
254 254
255 255 class A(HasTraits):
256 256 a = Int
257 257
258 258 class B(A):
259 259 b = Float
260 260
261 261 b = B()
262 262 self.assertEqual(b.a,0)
263 263 self.assertEqual(b.b,0.0)
264 264 b.a = 100
265 265 b.b = 100.0
266 266 self.assertEqual(b.a,100)
267 267 self.assertEqual(b.b,100.0)
268 268
269 269 def test_notify_subclass(self):
270 270
271 271 class A(HasTraits):
272 272 a = Int
273 273
274 274 class B(A):
275 275 b = Float
276 276
277 277 b = B()
278 278 b.on_trait_change(self.notify1, 'a')
279 279 b.on_trait_change(self.notify2, 'b')
280 280 b.a = 0
281 281 b.b = 0.0
282 282 self.assertEqual(len(self._notify1),0)
283 283 self.assertEqual(len(self._notify2),0)
284 284 b.a = 10
285 285 b.b = 10.0
286 286 self.assertTrue(('a',0,10) in self._notify1)
287 287 self.assertTrue(('b',0.0,10.0) in self._notify2)
288 288
289 289 def test_static_notify(self):
290 290
291 291 class A(HasTraits):
292 292 a = Int
293 293 _notify1 = []
294 294 def _a_changed(self, name, old, new):
295 295 self._notify1.append((name, old, new))
296 296
297 297 a = A()
298 298 a.a = 0
299 299 # This is broken!!!
300 300 self.assertEqual(len(a._notify1),0)
301 301 a.a = 10
302 302 self.assertTrue(('a',0,10) in a._notify1)
303 303
304 304 class B(A):
305 305 b = Float
306 306 _notify2 = []
307 307 def _b_changed(self, name, old, new):
308 308 self._notify2.append((name, old, new))
309 309
310 310 b = B()
311 311 b.a = 10
312 312 b.b = 10.0
313 313 self.assertTrue(('a',0,10) in b._notify1)
314 314 self.assertTrue(('b',0.0,10.0) in b._notify2)
315 315
316 316 def test_notify_args(self):
317 317
318 318 def callback0():
319 319 self.cb = ()
320 320 def callback1(name):
321 321 self.cb = (name,)
322 322 def callback2(name, new):
323 323 self.cb = (name, new)
324 324 def callback3(name, old, new):
325 325 self.cb = (name, old, new)
326 326
327 327 class A(HasTraits):
328 328 a = Int
329 329
330 330 a = A()
331 331 a.on_trait_change(callback0, 'a')
332 332 a.a = 10
333 333 self.assertEqual(self.cb,())
334 334 a.on_trait_change(callback0, 'a', remove=True)
335 335
336 336 a.on_trait_change(callback1, 'a')
337 337 a.a = 100
338 338 self.assertEqual(self.cb,('a',))
339 339 a.on_trait_change(callback1, 'a', remove=True)
340 340
341 341 a.on_trait_change(callback2, 'a')
342 342 a.a = 1000
343 343 self.assertEqual(self.cb,('a',1000))
344 344 a.on_trait_change(callback2, 'a', remove=True)
345 345
346 346 a.on_trait_change(callback3, 'a')
347 347 a.a = 10000
348 348 self.assertEqual(self.cb,('a',1000,10000))
349 349 a.on_trait_change(callback3, 'a', remove=True)
350 350
351 351 self.assertEqual(len(a._trait_notifiers['a']),0)
352 352
353 353 def test_notify_only_once(self):
354 354
355 355 class A(HasTraits):
356 356 listen_to = ['a']
357 357
358 358 a = Int(0)
359 359 b = 0
360 360
361 361 def __init__(self, **kwargs):
362 362 super(A, self).__init__(**kwargs)
363 363 self.on_trait_change(self.listener1, ['a'])
364 364
365 365 def listener1(self, name, old, new):
366 366 self.b += 1
367 367
368 368 class B(A):
369 369
370 370 c = 0
371 371 d = 0
372 372
373 373 def __init__(self, **kwargs):
374 374 super(B, self).__init__(**kwargs)
375 375 self.on_trait_change(self.listener2)
376 376
377 377 def listener2(self, name, old, new):
378 378 self.c += 1
379 379
380 380 def _a_changed(self, name, old, new):
381 381 self.d += 1
382 382
383 383 b = B()
384 384 b.a += 1
385 385 self.assertEqual(b.b, b.c)
386 386 self.assertEqual(b.b, b.d)
387 387 b.a += 1
388 388 self.assertEqual(b.b, b.c)
389 389 self.assertEqual(b.b, b.d)
390 390
391 391
392 392 class TestHasTraits(TestCase):
393 393
394 394 def test_trait_names(self):
395 395 class A(HasTraits):
396 396 i = Int
397 397 f = Float
398 398 a = A()
399 399 self.assertEqual(sorted(a.trait_names()),['f','i'])
400 400 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
401 401
402 402 def test_trait_metadata(self):
403 403 class A(HasTraits):
404 404 i = Int(config_key='MY_VALUE')
405 405 a = A()
406 406 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
407 407
408 408 def test_traits(self):
409 409 class A(HasTraits):
410 410 i = Int
411 411 f = Float
412 412 a = A()
413 413 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
414 414 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
415 415
416 416 def test_traits_metadata(self):
417 417 class A(HasTraits):
418 418 i = Int(config_key='VALUE1', other_thing='VALUE2')
419 419 f = Float(config_key='VALUE3', other_thing='VALUE2')
420 420 j = Int(0)
421 421 a = A()
422 422 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
423 423 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
424 424 self.assertEqual(traits, dict(i=A.i))
425 425
426 426 # This passes, but it shouldn't because I am replicating a bug in
427 427 # traits.
428 428 traits = a.traits(config_key=lambda v: True)
429 429 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
430 430
431 431 def test_init(self):
432 432 class A(HasTraits):
433 433 i = Int()
434 434 x = Float()
435 435 a = A(i=1, x=10.0)
436 436 self.assertEqual(a.i, 1)
437 437 self.assertEqual(a.x, 10.0)
438 438
439 439 def test_positional_args(self):
440 440 class A(HasTraits):
441 441 i = Int(0)
442 442 def __init__(self, i):
443 443 super(A, self).__init__()
444 444 self.i = i
445 445
446 446 a = A(5)
447 447 self.assertEqual(a.i, 5)
448 448 # should raise TypeError if no positional arg given
449 449 self.assertRaises(TypeError, A)
450 450
451 451 #-----------------------------------------------------------------------------
452 452 # Tests for specific trait types
453 453 #-----------------------------------------------------------------------------
454 454
455 455
456 456 class TestType(TestCase):
457 457
458 458 def test_default(self):
459 459
460 460 class B(object): pass
461 461 class A(HasTraits):
462 462 klass = Type
463 463
464 464 a = A()
465 465 self.assertEqual(a.klass, None)
466 466
467 467 a.klass = B
468 468 self.assertEqual(a.klass, B)
469 469 self.assertRaises(TraitError, setattr, a, 'klass', 10)
470 470
471 471 def test_value(self):
472 472
473 473 class B(object): pass
474 474 class C(object): pass
475 475 class A(HasTraits):
476 476 klass = Type(B)
477 477
478 478 a = A()
479 479 self.assertEqual(a.klass, B)
480 480 self.assertRaises(TraitError, setattr, a, 'klass', C)
481 481 self.assertRaises(TraitError, setattr, a, 'klass', object)
482 482 a.klass = B
483 483
484 484 def test_allow_none(self):
485 485
486 486 class B(object): pass
487 487 class C(B): pass
488 488 class A(HasTraits):
489 489 klass = Type(B, allow_none=False)
490 490
491 491 a = A()
492 492 self.assertEqual(a.klass, B)
493 493 self.assertRaises(TraitError, setattr, a, 'klass', None)
494 494 a.klass = C
495 495 self.assertEqual(a.klass, C)
496 496
497 497 def test_validate_klass(self):
498 498
499 499 class A(HasTraits):
500 500 klass = Type('no strings allowed')
501 501
502 502 self.assertRaises(ImportError, A)
503 503
504 504 class A(HasTraits):
505 505 klass = Type('rub.adub.Duck')
506 506
507 507 self.assertRaises(ImportError, A)
508 508
509 509 def test_validate_default(self):
510 510
511 511 class B(object): pass
512 512 class A(HasTraits):
513 513 klass = Type('bad default', B)
514 514
515 515 self.assertRaises(ImportError, A)
516 516
517 517 class C(HasTraits):
518 518 klass = Type(None, B, allow_none=False)
519 519
520 520 self.assertRaises(TraitError, C)
521 521
522 522 def test_str_klass(self):
523 523
524 524 class A(HasTraits):
525 525 klass = Type('IPython.utils.ipstruct.Struct')
526 526
527 527 from IPython.utils.ipstruct import Struct
528 528 a = A()
529 529 a.klass = Struct
530 530 self.assertEqual(a.klass, Struct)
531 531
532 532 self.assertRaises(TraitError, setattr, a, 'klass', 10)
533 533
534 534 def test_set_str_klass(self):
535 535
536 536 class A(HasTraits):
537 537 klass = Type()
538 538
539 539 a = A(klass='IPython.utils.ipstruct.Struct')
540 540 from IPython.utils.ipstruct import Struct
541 541 self.assertEqual(a.klass, Struct)
542 542
543 543 class TestInstance(TestCase):
544 544
545 545 def test_basic(self):
546 546 class Foo(object): pass
547 547 class Bar(Foo): pass
548 548 class Bah(object): pass
549 549
550 550 class A(HasTraits):
551 551 inst = Instance(Foo)
552 552
553 553 a = A()
554 554 self.assertTrue(a.inst is None)
555 555 a.inst = Foo()
556 556 self.assertTrue(isinstance(a.inst, Foo))
557 557 a.inst = Bar()
558 558 self.assertTrue(isinstance(a.inst, Foo))
559 559 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
560 560 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
561 561 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
562 562
563 def test_default_klass(self):
564 class Foo(object): pass
565 class Bar(Foo): pass
566 class Bah(object): pass
567
568 class FooInstance(Instance):
569 klass = Foo
570
571 class A(HasTraits):
572 inst = FooInstance()
573
574 a = A()
575 self.assertTrue(a.inst is None)
576 a.inst = Foo()
577 self.assertTrue(isinstance(a.inst, Foo))
578 a.inst = Bar()
579 self.assertTrue(isinstance(a.inst, Foo))
580 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
581 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
582 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
583
563 584 def test_unique_default_value(self):
564 585 class Foo(object): pass
565 586 class A(HasTraits):
566 587 inst = Instance(Foo,(),{})
567 588
568 589 a = A()
569 590 b = A()
570 591 self.assertTrue(a.inst is not b.inst)
571 592
572 593 def test_args_kw(self):
573 594 class Foo(object):
574 595 def __init__(self, c): self.c = c
575 596 class Bar(object): pass
576 597 class Bah(object):
577 598 def __init__(self, c, d):
578 599 self.c = c; self.d = d
579 600
580 601 class A(HasTraits):
581 602 inst = Instance(Foo, (10,))
582 603 a = A()
583 604 self.assertEqual(a.inst.c, 10)
584 605
585 606 class B(HasTraits):
586 607 inst = Instance(Bah, args=(10,), kw=dict(d=20))
587 608 b = B()
588 609 self.assertEqual(b.inst.c, 10)
589 610 self.assertEqual(b.inst.d, 20)
590 611
591 612 class C(HasTraits):
592 613 inst = Instance(Foo)
593 614 c = C()
594 615 self.assertTrue(c.inst is None)
595 616
596 617 def test_bad_default(self):
597 618 class Foo(object): pass
598 619
599 620 class A(HasTraits):
600 621 inst = Instance(Foo, allow_none=False)
601 622
602 623 self.assertRaises(TraitError, A)
603 624
604 625 def test_instance(self):
605 626 class Foo(object): pass
606 627
607 628 def inner():
608 629 class A(HasTraits):
609 630 inst = Instance(Foo())
610 631
611 632 self.assertRaises(TraitError, inner)
612 633
613 634
614 635 class TestThis(TestCase):
615 636
616 637 def test_this_class(self):
617 638 class Foo(HasTraits):
618 639 this = This
619 640
620 641 f = Foo()
621 642 self.assertEqual(f.this, None)
622 643 g = Foo()
623 644 f.this = g
624 645 self.assertEqual(f.this, g)
625 646 self.assertRaises(TraitError, setattr, f, 'this', 10)
626 647
627 648 def test_this_inst(self):
628 649 class Foo(HasTraits):
629 650 this = This()
630 651
631 652 f = Foo()
632 653 f.this = Foo()
633 654 self.assertTrue(isinstance(f.this, Foo))
634 655
635 656 def test_subclass(self):
636 657 class Foo(HasTraits):
637 658 t = This()
638 659 class Bar(Foo):
639 660 pass
640 661 f = Foo()
641 662 b = Bar()
642 663 f.t = b
643 664 b.t = f
644 665 self.assertEqual(f.t, b)
645 666 self.assertEqual(b.t, f)
646 667
647 668 def test_subclass_override(self):
648 669 class Foo(HasTraits):
649 670 t = This()
650 671 class Bar(Foo):
651 672 t = This()
652 673 f = Foo()
653 674 b = Bar()
654 675 f.t = b
655 676 self.assertEqual(f.t, b)
656 677 self.assertRaises(TraitError, setattr, b, 't', f)
657 678
658 679 class TraitTestBase(TestCase):
659 680 """A best testing class for basic trait types."""
660 681
661 682 def assign(self, value):
662 683 self.obj.value = value
663 684
664 685 def coerce(self, value):
665 686 return value
666 687
667 688 def test_good_values(self):
668 689 if hasattr(self, '_good_values'):
669 690 for value in self._good_values:
670 691 self.assign(value)
671 692 self.assertEqual(self.obj.value, self.coerce(value))
672 693
673 694 def test_bad_values(self):
674 695 if hasattr(self, '_bad_values'):
675 696 for value in self._bad_values:
676 697 try:
677 698 self.assertRaises(TraitError, self.assign, value)
678 699 except AssertionError:
679 700 assert False, value
680 701
681 702 def test_default_value(self):
682 703 if hasattr(self, '_default_value'):
683 704 self.assertEqual(self._default_value, self.obj.value)
684 705
685 706 def test_allow_none(self):
686 707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
687 708 None in self._bad_values):
688 709 trait=self.obj.traits()['value']
689 710 try:
690 711 trait.allow_none = True
691 712 self._bad_values.remove(None)
692 713 #skip coerce. Allow None casts None to None.
693 714 self.assign(None)
694 715 self.assertEqual(self.obj.value,None)
695 716 self.test_good_values()
696 717 self.test_bad_values()
697 718 finally:
698 719 #tear down
699 720 trait.allow_none = False
700 721 self._bad_values.append(None)
701 722
702 723 def tearDown(self):
703 724 # restore default value after tests, if set
704 725 if hasattr(self, '_default_value'):
705 726 self.obj.value = self._default_value
706 727
707 728
708 729 class AnyTrait(HasTraits):
709 730
710 731 value = Any
711 732
712 733 class AnyTraitTest(TraitTestBase):
713 734
714 735 obj = AnyTrait()
715 736
716 737 _default_value = None
717 738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
718 739 _bad_values = []
719 740
720 741
721 742 class IntTrait(HasTraits):
722 743
723 744 value = Int(99)
724 745
725 746 class TestInt(TraitTestBase):
726 747
727 748 obj = IntTrait()
728 749 _default_value = 99
729 750 _good_values = [10, -10]
730 751 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
731 752 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
732 753 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
733 754 if not py3compat.PY3:
734 755 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
735 756
736 757
737 758 class LongTrait(HasTraits):
738 759
739 760 value = Long(99 if py3compat.PY3 else long(99))
740 761
741 762 class TestLong(TraitTestBase):
742 763
743 764 obj = LongTrait()
744 765
745 766 _default_value = 99 if py3compat.PY3 else long(99)
746 767 _good_values = [10, -10]
747 768 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
748 769 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
749 770 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
750 771 u'-10.1']
751 772 if not py3compat.PY3:
752 773 # maxint undefined on py3, because int == long
753 774 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
754 775 _bad_values.extend([[long(10)], (long(10),)])
755 776
756 777 @skipif(py3compat.PY3, "not relevant on py3")
757 778 def test_cast_small(self):
758 779 """Long casts ints to long"""
759 780 self.obj.value = 10
760 781 self.assertEqual(type(self.obj.value), long)
761 782
762 783
763 784 class IntegerTrait(HasTraits):
764 785 value = Integer(1)
765 786
766 787 class TestInteger(TestLong):
767 788 obj = IntegerTrait()
768 789 _default_value = 1
769 790
770 791 def coerce(self, n):
771 792 return int(n)
772 793
773 794 @skipif(py3compat.PY3, "not relevant on py3")
774 795 def test_cast_small(self):
775 796 """Integer casts small longs to int"""
776 797 if py3compat.PY3:
777 798 raise SkipTest("not relevant on py3")
778 799
779 800 self.obj.value = long(100)
780 801 self.assertEqual(type(self.obj.value), int)
781 802
782 803
783 804 class FloatTrait(HasTraits):
784 805
785 806 value = Float(99.0)
786 807
787 808 class TestFloat(TraitTestBase):
788 809
789 810 obj = FloatTrait()
790 811
791 812 _default_value = 99.0
792 813 _good_values = [10, -10, 10.1, -10.1]
793 814 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
794 815 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
795 816 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
796 817 if not py3compat.PY3:
797 818 _bad_values.extend([long(10), long(-10)])
798 819
799 820
800 821 class ComplexTrait(HasTraits):
801 822
802 823 value = Complex(99.0-99.0j)
803 824
804 825 class TestComplex(TraitTestBase):
805 826
806 827 obj = ComplexTrait()
807 828
808 829 _default_value = 99.0-99.0j
809 830 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
810 831 10.1j, 10.1+10.1j, 10.1-10.1j]
811 832 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
812 833 if not py3compat.PY3:
813 834 _bad_values.extend([long(10), long(-10)])
814 835
815 836
816 837 class BytesTrait(HasTraits):
817 838
818 839 value = Bytes(b'string')
819 840
820 841 class TestBytes(TraitTestBase):
821 842
822 843 obj = BytesTrait()
823 844
824 845 _default_value = b'string'
825 846 _good_values = [b'10', b'-10', b'10L',
826 847 b'-10L', b'10.1', b'-10.1', b'string']
827 848 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
828 849 ['ten'],{'ten': 10},(10,), None, u'string']
829 850 if not py3compat.PY3:
830 851 _bad_values.extend([long(10), long(-10)])
831 852
832 853
833 854 class UnicodeTrait(HasTraits):
834 855
835 856 value = Unicode(u'unicode')
836 857
837 858 class TestUnicode(TraitTestBase):
838 859
839 860 obj = UnicodeTrait()
840 861
841 862 _default_value = u'unicode'
842 863 _good_values = ['10', '-10', '10L', '-10L', '10.1',
843 864 '-10.1', '', u'', 'string', u'string', u"€"]
844 865 _bad_values = [10, -10, 10.1, -10.1, 1j,
845 866 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
846 867 if not py3compat.PY3:
847 868 _bad_values.extend([long(10), long(-10)])
848 869
849 870
850 871 class ObjectNameTrait(HasTraits):
851 872 value = ObjectName("abc")
852 873
853 874 class TestObjectName(TraitTestBase):
854 875 obj = ObjectNameTrait()
855 876
856 877 _default_value = "abc"
857 878 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
858 879 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
859 880 None, object(), object]
860 881 if sys.version_info[0] < 3:
861 882 _bad_values.append(u"ΓΎ")
862 883 else:
863 884 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
864 885
865 886
866 887 class DottedObjectNameTrait(HasTraits):
867 888 value = DottedObjectName("a.b")
868 889
869 890 class TestDottedObjectName(TraitTestBase):
870 891 obj = DottedObjectNameTrait()
871 892
872 893 _default_value = "a.b"
873 894 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
874 895 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
875 896 if sys.version_info[0] < 3:
876 897 _bad_values.append(u"t.ΓΎ")
877 898 else:
878 899 _good_values.append(u"t.ΓΎ")
879 900
880 901
881 902 class TCPAddressTrait(HasTraits):
882 903
883 904 value = TCPAddress()
884 905
885 906 class TestTCPAddress(TraitTestBase):
886 907
887 908 obj = TCPAddressTrait()
888 909
889 910 _default_value = ('127.0.0.1',0)
890 911 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
891 912 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
892 913
893 914 class ListTrait(HasTraits):
894 915
895 916 value = List(Int)
896 917
897 918 class TestList(TraitTestBase):
898 919
899 920 obj = ListTrait()
900 921
901 922 _default_value = []
902 923 _good_values = [[], [1], list(range(10)), (1,2)]
903 924 _bad_values = [10, [1,'a'], 'a']
904 925
905 926 def coerce(self, value):
906 927 if value is not None:
907 928 value = list(value)
908 929 return value
909 930
910 931 class Foo(object):
911 932 pass
912 933
913 934 class InstanceListTrait(HasTraits):
914 935
915 936 value = List(Instance(__name__+'.Foo'))
916 937
917 938 class TestInstanceList(TraitTestBase):
918 939
919 940 obj = InstanceListTrait()
920 941
921 942 def test_klass(self):
922 943 """Test that the instance klass is properly assigned."""
923 944 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
924 945
925 946 _default_value = []
926 947 _good_values = [[Foo(), Foo(), None], None]
927 948 _bad_values = [['1', 2,], '1', [Foo]]
928 949
929 950 class LenListTrait(HasTraits):
930 951
931 952 value = List(Int, [0], minlen=1, maxlen=2)
932 953
933 954 class TestLenList(TraitTestBase):
934 955
935 956 obj = LenListTrait()
936 957
937 958 _default_value = [0]
938 959 _good_values = [[1], [1,2], (1,2)]
939 960 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
940 961
941 962 def coerce(self, value):
942 963 if value is not None:
943 964 value = list(value)
944 965 return value
945 966
946 967 class TupleTrait(HasTraits):
947 968
948 969 value = Tuple(Int(allow_none=True))
949 970
950 971 class TestTupleTrait(TraitTestBase):
951 972
952 973 obj = TupleTrait()
953 974
954 975 _default_value = None
955 976 _good_values = [(1,), None, (0,), [1], (None,)]
956 977 _bad_values = [10, (1,2), ('a'), ()]
957 978
958 979 def coerce(self, value):
959 980 if value is not None:
960 981 value = tuple(value)
961 982 return value
962 983
963 984 def test_invalid_args(self):
964 985 self.assertRaises(TypeError, Tuple, 5)
965 986 self.assertRaises(TypeError, Tuple, default_value='hello')
966 987 t = Tuple(Int, CBytes, default_value=(1,5))
967 988
968 989 class LooseTupleTrait(HasTraits):
969 990
970 991 value = Tuple((1,2,3))
971 992
972 993 class TestLooseTupleTrait(TraitTestBase):
973 994
974 995 obj = LooseTupleTrait()
975 996
976 997 _default_value = (1,2,3)
977 998 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
978 999 _bad_values = [10, 'hello', {}]
979 1000
980 1001 def coerce(self, value):
981 1002 if value is not None:
982 1003 value = tuple(value)
983 1004 return value
984 1005
985 1006 def test_invalid_args(self):
986 1007 self.assertRaises(TypeError, Tuple, 5)
987 1008 self.assertRaises(TypeError, Tuple, default_value='hello')
988 1009 t = Tuple(Int, CBytes, default_value=(1,5))
989 1010
990 1011
991 1012 class MultiTupleTrait(HasTraits):
992 1013
993 1014 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
994 1015
995 1016 class TestMultiTuple(TraitTestBase):
996 1017
997 1018 obj = MultiTupleTrait()
998 1019
999 1020 _default_value = (99,b'bottles')
1000 1021 _good_values = [(1,b'a'), (2,b'b')]
1001 1022 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1002 1023
1003 1024 class CRegExpTrait(HasTraits):
1004 1025
1005 1026 value = CRegExp(r'')
1006 1027
1007 1028 class TestCRegExp(TraitTestBase):
1008 1029
1009 1030 def coerce(self, value):
1010 1031 return re.compile(value)
1011 1032
1012 1033 obj = CRegExpTrait()
1013 1034
1014 1035 _default_value = re.compile(r'')
1015 1036 _good_values = [r'\d+', re.compile(r'\d+')]
1016 1037 _bad_values = [r'(', None, ()]
1017 1038
1018 1039 class DictTrait(HasTraits):
1019 1040 value = Dict()
1020 1041
1021 1042 def test_dict_assignment():
1022 1043 d = dict()
1023 1044 c = DictTrait()
1024 1045 c.value = d
1025 1046 d['a'] = 5
1026 1047 nt.assert_equal(d, c.value)
1027 1048 nt.assert_true(c.value is d)
1028 1049
1029 1050 class TestLink(TestCase):
1030 1051 def test_connect_same(self):
1031 1052 """Verify two traitlets of the same type can be linked together using link."""
1032 1053
1033 1054 # Create two simple classes with Int traitlets.
1034 1055 class A(HasTraits):
1035 1056 value = Int()
1036 1057 a = A(value=9)
1037 1058 b = A(value=8)
1038 1059
1039 1060 # Conenct the two classes.
1040 1061 c = link((a, 'value'), (b, 'value'))
1041 1062
1042 1063 # Make sure the values are the same at the point of linking.
1043 1064 self.assertEqual(a.value, b.value)
1044 1065
1045 1066 # Change one of the values to make sure they stay in sync.
1046 1067 a.value = 5
1047 1068 self.assertEqual(a.value, b.value)
1048 1069 b.value = 6
1049 1070 self.assertEqual(a.value, b.value)
1050 1071
1051 1072 def test_link_different(self):
1052 1073 """Verify two traitlets of different types can be linked together using link."""
1053 1074
1054 1075 # Create two simple classes with Int traitlets.
1055 1076 class A(HasTraits):
1056 1077 value = Int()
1057 1078 class B(HasTraits):
1058 1079 count = Int()
1059 1080 a = A(value=9)
1060 1081 b = B(count=8)
1061 1082
1062 1083 # Conenct the two classes.
1063 1084 c = link((a, 'value'), (b, 'count'))
1064 1085
1065 1086 # Make sure the values are the same at the point of linking.
1066 1087 self.assertEqual(a.value, b.count)
1067 1088
1068 1089 # Change one of the values to make sure they stay in sync.
1069 1090 a.value = 5
1070 1091 self.assertEqual(a.value, b.count)
1071 1092 b.count = 4
1072 1093 self.assertEqual(a.value, b.count)
1073 1094
1074 1095 def test_unlink(self):
1075 1096 """Verify two linked traitlets can be unlinked."""
1076 1097
1077 1098 # Create two simple classes with Int traitlets.
1078 1099 class A(HasTraits):
1079 1100 value = Int()
1080 1101 a = A(value=9)
1081 1102 b = A(value=8)
1082 1103
1083 1104 # Connect the two classes.
1084 1105 c = link((a, 'value'), (b, 'value'))
1085 1106 a.value = 4
1086 1107 c.unlink()
1087 1108
1088 1109 # Change one of the values to make sure they don't stay in sync.
1089 1110 a.value = 5
1090 1111 self.assertNotEqual(a.value, b.value)
1091 1112
1092 1113 def test_callbacks(self):
1093 1114 """Verify two linked traitlets have their callbacks called once."""
1094 1115
1095 1116 # Create two simple classes with Int traitlets.
1096 1117 class A(HasTraits):
1097 1118 value = Int()
1098 1119 class B(HasTraits):
1099 1120 count = Int()
1100 1121 a = A(value=9)
1101 1122 b = B(count=8)
1102 1123
1103 1124 # Register callbacks that count.
1104 1125 callback_count = []
1105 1126 def a_callback(name, old, new):
1106 1127 callback_count.append('a')
1107 1128 a.on_trait_change(a_callback, 'value')
1108 1129 def b_callback(name, old, new):
1109 1130 callback_count.append('b')
1110 1131 b.on_trait_change(b_callback, 'count')
1111 1132
1112 1133 # Connect the two classes.
1113 1134 c = link((a, 'value'), (b, 'count'))
1114 1135
1115 1136 # Make sure b's count was set to a's value once.
1116 1137 self.assertEqual(''.join(callback_count), 'b')
1117 1138 del callback_count[:]
1118 1139
1119 1140 # Make sure a's value was set to b's count once.
1120 1141 b.count = 5
1121 1142 self.assertEqual(''.join(callback_count), 'ba')
1122 1143 del callback_count[:]
1123 1144
1124 1145 # Make sure b's count was set to a's value once.
1125 1146 a.value = 4
1126 1147 self.assertEqual(''.join(callback_count), 'ab')
1127 1148 del callback_count[:]
1128 1149
1129 1150 class Pickleable(HasTraits):
1130 1151 i = Int()
1131 1152 j = Int()
1132 1153
1133 1154 def _i_default(self):
1134 1155 return 1
1135 1156
1136 1157 def _i_changed(self, name, old, new):
1137 1158 self.j = new
1138 1159
1139 1160 def test_pickle_hastraits():
1140 1161 c = Pickleable()
1141 1162 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1142 1163 p = pickle.dumps(c, protocol)
1143 1164 c2 = pickle.loads(p)
1144 1165 nt.assert_equal(c2.i, c.i)
1145 1166 nt.assert_equal(c2.j, c.j)
1146 1167
1147 1168 c.i = 5
1148 1169 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1149 1170 p = pickle.dumps(c, protocol)
1150 1171 c2 = pickle.loads(p)
1151 1172 nt.assert_equal(c2.i, c.i)
1152 1173 nt.assert_equal(c2.j, c.j)
1153 1174
@@ -1,1516 +1,1523
1 1 # encoding: utf-8
2 2 """
3 3 A lightweight Traits like module.
4 4
5 5 This is designed to provide a lightweight, simple, pure Python version of
6 6 many of the capabilities of enthought.traits. This includes:
7 7
8 8 * Validation
9 9 * Type specification with defaults
10 10 * Static and dynamic notification
11 11 * Basic predefined types
12 12 * An API that is similar to enthought.traits
13 13
14 14 We don't support:
15 15
16 16 * Delegation
17 17 * Automatic GUI generation
18 18 * A full set of trait types. Most importantly, we don't provide container
19 19 traits (list, dict, tuple) that can trigger notifications if their
20 20 contents change.
21 21 * API compatibility with enthought.traits
22 22
23 23 There are also some important difference in our design:
24 24
25 25 * enthought.traits does not validate default values. We do.
26 26
27 27 We choose to create this module because we need these capabilities, but
28 28 we need them to be pure Python so they work in all Python implementations,
29 29 including Jython and IronPython.
30 30
31 31 Inheritance diagram:
32 32
33 33 .. inheritance-diagram:: IPython.utils.traitlets
34 34 :parts: 3
35 35 """
36 36
37 37 # Copyright (c) IPython Development Team.
38 38 # Distributed under the terms of the Modified BSD License.
39 39 #
40 40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 41 # also under the terms of the Modified BSD License.
42 42
43 43 import contextlib
44 44 import inspect
45 45 import re
46 46 import sys
47 47 import types
48 48 from types import FunctionType
49 49 try:
50 50 from types import ClassType, InstanceType
51 51 ClassTypes = (ClassType, type)
52 52 except:
53 53 ClassTypes = (type,)
54 54
55 55 from .importstring import import_item
56 56 from IPython.utils import py3compat
57 57 from IPython.utils.py3compat import iteritems
58 58 from IPython.testing.skipdoctest import skip_doctest
59 59
60 60 SequenceTypes = (list, tuple, set, frozenset)
61 61
62 62 #-----------------------------------------------------------------------------
63 63 # Basic classes
64 64 #-----------------------------------------------------------------------------
65 65
66 66
67 67 class NoDefaultSpecified ( object ): pass
68 68 NoDefaultSpecified = NoDefaultSpecified()
69 69
70 70
71 71 class Undefined ( object ): pass
72 72 Undefined = Undefined()
73 73
74 74 class TraitError(Exception):
75 75 pass
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # Utilities
79 79 #-----------------------------------------------------------------------------
80 80
81 81
82 82 def class_of ( object ):
83 83 """ Returns a string containing the class name of an object with the
84 84 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
85 85 'a PlotValue').
86 86 """
87 87 if isinstance( object, py3compat.string_types ):
88 88 return add_article( object )
89 89
90 90 return add_article( object.__class__.__name__ )
91 91
92 92
93 93 def add_article ( name ):
94 94 """ Returns a string containing the correct indefinite article ('a' or 'an')
95 95 prefixed to the specified string.
96 96 """
97 97 if name[:1].lower() in 'aeiou':
98 98 return 'an ' + name
99 99
100 100 return 'a ' + name
101 101
102 102
103 103 def repr_type(obj):
104 104 """ Return a string representation of a value and its type for readable
105 105 error messages.
106 106 """
107 107 the_type = type(obj)
108 108 if (not py3compat.PY3) and the_type is InstanceType:
109 109 # Old-style class.
110 110 the_type = obj.__class__
111 111 msg = '%r %r' % (obj, the_type)
112 112 return msg
113 113
114 114
115 115 def is_trait(t):
116 116 """ Returns whether the given value is an instance or subclass of TraitType.
117 117 """
118 118 return (isinstance(t, TraitType) or
119 119 (isinstance(t, type) and issubclass(t, TraitType)))
120 120
121 121
122 122 def parse_notifier_name(name):
123 123 """Convert the name argument to a list of names.
124 124
125 125 Examples
126 126 --------
127 127
128 128 >>> parse_notifier_name('a')
129 129 ['a']
130 130 >>> parse_notifier_name(['a','b'])
131 131 ['a', 'b']
132 132 >>> parse_notifier_name(None)
133 133 ['anytrait']
134 134 """
135 135 if isinstance(name, str):
136 136 return [name]
137 137 elif name is None:
138 138 return ['anytrait']
139 139 elif isinstance(name, (list, tuple)):
140 140 for n in name:
141 141 assert isinstance(n, str), "names must be strings"
142 142 return name
143 143
144 144
145 145 class _SimpleTest:
146 146 def __init__ ( self, value ): self.value = value
147 147 def __call__ ( self, test ):
148 148 return test == self.value
149 149 def __repr__(self):
150 150 return "<SimpleTest(%r)" % self.value
151 151 def __str__(self):
152 152 return self.__repr__()
153 153
154 154
155 155 def getmembers(object, predicate=None):
156 156 """A safe version of inspect.getmembers that handles missing attributes.
157 157
158 158 This is useful when there are descriptor based attributes that for
159 159 some reason raise AttributeError even though they exist. This happens
160 160 in zope.inteface with the __provides__ attribute.
161 161 """
162 162 results = []
163 163 for key in dir(object):
164 164 try:
165 165 value = getattr(object, key)
166 166 except AttributeError:
167 167 pass
168 168 else:
169 169 if not predicate or predicate(value):
170 170 results.append((key, value))
171 171 results.sort()
172 172 return results
173 173
174 174 @skip_doctest
175 175 class link(object):
176 176 """Link traits from different objects together so they remain in sync.
177 177
178 178 Parameters
179 179 ----------
180 180 obj : pairs of objects/attributes
181 181
182 182 Examples
183 183 --------
184 184
185 185 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
186 186 >>> obj1.value = 5 # updates other objects as well
187 187 """
188 188 updating = False
189 189 def __init__(self, *args):
190 190 if len(args) < 2:
191 191 raise TypeError('At least two traitlets must be provided.')
192 192
193 193 self.objects = {}
194 194 initial = getattr(args[0][0], args[0][1])
195 195 for obj,attr in args:
196 196 if getattr(obj, attr) != initial:
197 197 setattr(obj, attr, initial)
198 198
199 199 callback = self._make_closure(obj,attr)
200 200 obj.on_trait_change(callback, attr)
201 201 self.objects[(obj,attr)] = callback
202 202
203 203 @contextlib.contextmanager
204 204 def _busy_updating(self):
205 205 self.updating = True
206 206 try:
207 207 yield
208 208 finally:
209 209 self.updating = False
210 210
211 211 def _make_closure(self, sending_obj, sending_attr):
212 212 def update(name, old, new):
213 213 self._update(sending_obj, sending_attr, new)
214 214 return update
215 215
216 216 def _update(self, sending_obj, sending_attr, new):
217 217 if self.updating:
218 218 return
219 219 with self._busy_updating():
220 220 for obj,attr in self.objects.keys():
221 221 if obj is not sending_obj or attr != sending_attr:
222 222 setattr(obj, attr, new)
223 223
224 224 def unlink(self):
225 225 for key, callback in self.objects.items():
226 226 (obj,attr) = key
227 227 obj.on_trait_change(callback, attr, remove=True)
228 228
229 229 #-----------------------------------------------------------------------------
230 230 # Base TraitType for all traits
231 231 #-----------------------------------------------------------------------------
232 232
233 233
234 234 class TraitType(object):
235 235 """A base class for all trait descriptors.
236 236
237 237 Notes
238 238 -----
239 239 Our implementation of traits is based on Python's descriptor
240 240 prototol. This class is the base class for all such descriptors. The
241 241 only magic we use is a custom metaclass for the main :class:`HasTraits`
242 242 class that does the following:
243 243
244 244 1. Sets the :attr:`name` attribute of every :class:`TraitType`
245 245 instance in the class dict to the name of the attribute.
246 246 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
247 247 instance in the class dict to the *class* that declared the trait.
248 248 This is used by the :class:`This` trait to allow subclasses to
249 249 accept superclasses for :class:`This` values.
250 250 """
251 251
252 252
253 253 metadata = {}
254 254 default_value = Undefined
255 255 allow_none = False
256 256 info_text = 'any value'
257 257
258 258 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
259 259 """Create a TraitType.
260 260 """
261 261 if default_value is not NoDefaultSpecified:
262 262 self.default_value = default_value
263 263 if allow_none is not None:
264 264 self.allow_none = allow_none
265 265
266 266 if len(metadata) > 0:
267 267 if len(self.metadata) > 0:
268 268 self._metadata = self.metadata.copy()
269 269 self._metadata.update(metadata)
270 270 else:
271 271 self._metadata = metadata
272 272 else:
273 273 self._metadata = self.metadata
274 274
275 275 self.init()
276 276
277 277 def init(self):
278 278 pass
279 279
280 280 def get_default_value(self):
281 281 """Create a new instance of the default value."""
282 282 return self.default_value
283 283
284 284 def instance_init(self, obj):
285 285 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
286 286
287 287 Some stages of initialization must be delayed until the parent
288 288 :class:`HasTraits` instance has been created. This method is
289 289 called in :meth:`HasTraits.__new__` after the instance has been
290 290 created.
291 291
292 292 This method trigger the creation and validation of default values
293 293 and also things like the resolution of str given class names in
294 294 :class:`Type` and :class`Instance`.
295 295
296 296 Parameters
297 297 ----------
298 298 obj : :class:`HasTraits` instance
299 299 The parent :class:`HasTraits` instance that has just been
300 300 created.
301 301 """
302 302 self.set_default_value(obj)
303 303
304 304 def set_default_value(self, obj):
305 305 """Set the default value on a per instance basis.
306 306
307 307 This method is called by :meth:`instance_init` to create and
308 308 validate the default value. The creation and validation of
309 309 default values must be delayed until the parent :class:`HasTraits`
310 310 class has been instantiated.
311 311 """
312 312 # Check for a deferred initializer defined in the same class as the
313 313 # trait declaration or above.
314 314 mro = type(obj).mro()
315 315 meth_name = '_%s_default' % self.name
316 316 for cls in mro[:mro.index(self.this_class)+1]:
317 317 if meth_name in cls.__dict__:
318 318 break
319 319 else:
320 320 # We didn't find one. Do static initialization.
321 321 dv = self.get_default_value()
322 322 newdv = self._validate(obj, dv)
323 323 obj._trait_values[self.name] = newdv
324 324 return
325 325 # Complete the dynamic initialization.
326 326 obj._trait_dyn_inits[self.name] = meth_name
327 327
328 328 def __get__(self, obj, cls=None):
329 329 """Get the value of the trait by self.name for the instance.
330 330
331 331 Default values are instantiated when :meth:`HasTraits.__new__`
332 332 is called. Thus by the time this method gets called either the
333 333 default value or a user defined value (they called :meth:`__set__`)
334 334 is in the :class:`HasTraits` instance.
335 335 """
336 336 if obj is None:
337 337 return self
338 338 else:
339 339 try:
340 340 value = obj._trait_values[self.name]
341 341 except KeyError:
342 342 # Check for a dynamic initializer.
343 343 if self.name in obj._trait_dyn_inits:
344 344 method = getattr(obj, obj._trait_dyn_inits[self.name])
345 345 value = method()
346 346 # FIXME: Do we really validate here?
347 347 value = self._validate(obj, value)
348 348 obj._trait_values[self.name] = value
349 349 return value
350 350 else:
351 351 raise TraitError('Unexpected error in TraitType: '
352 352 'both default value and dynamic initializer are '
353 353 'absent.')
354 354 except Exception:
355 355 # HasTraits should call set_default_value to populate
356 356 # this. So this should never be reached.
357 357 raise TraitError('Unexpected error in TraitType: '
358 358 'default value not set properly')
359 359 else:
360 360 return value
361 361
362 362 def __set__(self, obj, value):
363 363 new_value = self._validate(obj, value)
364 364 old_value = self.__get__(obj)
365 365 obj._trait_values[self.name] = new_value
366 366 try:
367 367 silent = bool(old_value == new_value)
368 368 except:
369 369 # if there is an error in comparing, default to notify
370 370 silent = False
371 371 if silent is not True:
372 372 # we explicitly compare silent to True just in case the equality
373 373 # comparison above returns something other than True/False
374 374 obj._notify_trait(self.name, old_value, new_value)
375 375
376 376 def _validate(self, obj, value):
377 377 if value is None and self.allow_none:
378 378 return value
379 379 if hasattr(self, 'validate'):
380 380 return self.validate(obj, value)
381 381 elif hasattr(self, 'is_valid_for'):
382 382 valid = self.is_valid_for(value)
383 383 if valid:
384 384 return value
385 385 else:
386 386 raise TraitError('invalid value for type: %r' % value)
387 387 elif hasattr(self, 'value_for'):
388 388 return self.value_for(value)
389 389 else:
390 390 return value
391 391
392 392 def info(self):
393 393 return self.info_text
394 394
395 395 def error(self, obj, value):
396 396 if obj is not None:
397 397 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
398 398 % (self.name, class_of(obj),
399 399 self.info(), repr_type(value))
400 400 else:
401 401 e = "The '%s' trait must be %s, but a value of %r was specified." \
402 402 % (self.name, self.info(), repr_type(value))
403 403 raise TraitError(e)
404 404
405 405 def get_metadata(self, key):
406 406 return getattr(self, '_metadata', {}).get(key, None)
407 407
408 408 def set_metadata(self, key, value):
409 409 getattr(self, '_metadata', {})[key] = value
410 410
411 411
412 412 #-----------------------------------------------------------------------------
413 413 # The HasTraits implementation
414 414 #-----------------------------------------------------------------------------
415 415
416 416
417 417 class MetaHasTraits(type):
418 418 """A metaclass for HasTraits.
419 419
420 420 This metaclass makes sure that any TraitType class attributes are
421 421 instantiated and sets their name attribute.
422 422 """
423 423
424 424 def __new__(mcls, name, bases, classdict):
425 425 """Create the HasTraits class.
426 426
427 427 This instantiates all TraitTypes in the class dict and sets their
428 428 :attr:`name` attribute.
429 429 """
430 430 # print "MetaHasTraitlets (mcls, name): ", mcls, name
431 431 # print "MetaHasTraitlets (bases): ", bases
432 432 # print "MetaHasTraitlets (classdict): ", classdict
433 433 for k,v in iteritems(classdict):
434 434 if isinstance(v, TraitType):
435 435 v.name = k
436 436 elif inspect.isclass(v):
437 437 if issubclass(v, TraitType):
438 438 vinst = v()
439 439 vinst.name = k
440 440 classdict[k] = vinst
441 441 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
442 442
443 443 def __init__(cls, name, bases, classdict):
444 444 """Finish initializing the HasTraits class.
445 445
446 446 This sets the :attr:`this_class` attribute of each TraitType in the
447 447 class dict to the newly created class ``cls``.
448 448 """
449 449 for k, v in iteritems(classdict):
450 450 if isinstance(v, TraitType):
451 451 v.this_class = cls
452 452 super(MetaHasTraits, cls).__init__(name, bases, classdict)
453 453
454 454 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
455 455
456 456 def __new__(cls, *args, **kw):
457 457 # This is needed because object.__new__ only accepts
458 458 # the cls argument.
459 459 new_meth = super(HasTraits, cls).__new__
460 460 if new_meth is object.__new__:
461 461 inst = new_meth(cls)
462 462 else:
463 463 inst = new_meth(cls, **kw)
464 464 inst._trait_values = {}
465 465 inst._trait_notifiers = {}
466 466 inst._trait_dyn_inits = {}
467 467 # Here we tell all the TraitType instances to set their default
468 468 # values on the instance.
469 469 for key in dir(cls):
470 470 # Some descriptors raise AttributeError like zope.interface's
471 471 # __provides__ attributes even though they exist. This causes
472 472 # AttributeErrors even though they are listed in dir(cls).
473 473 try:
474 474 value = getattr(cls, key)
475 475 except AttributeError:
476 476 pass
477 477 else:
478 478 if isinstance(value, TraitType):
479 479 value.instance_init(inst)
480 480
481 481 return inst
482 482
483 483 def __init__(self, *args, **kw):
484 484 # Allow trait values to be set using keyword arguments.
485 485 # We need to use setattr for this to trigger validation and
486 486 # notifications.
487 487 for key, value in iteritems(kw):
488 488 setattr(self, key, value)
489 489
490 490 def _notify_trait(self, name, old_value, new_value):
491 491
492 492 # First dynamic ones
493 493 callables = []
494 494 callables.extend(self._trait_notifiers.get(name,[]))
495 495 callables.extend(self._trait_notifiers.get('anytrait',[]))
496 496
497 497 # Now static ones
498 498 try:
499 499 cb = getattr(self, '_%s_changed' % name)
500 500 except:
501 501 pass
502 502 else:
503 503 callables.append(cb)
504 504
505 505 # Call them all now
506 506 for c in callables:
507 507 # Traits catches and logs errors here. I allow them to raise
508 508 if callable(c):
509 509 argspec = inspect.getargspec(c)
510 510 nargs = len(argspec[0])
511 511 # Bound methods have an additional 'self' argument
512 512 # I don't know how to treat unbound methods, but they
513 513 # can't really be used for callbacks.
514 514 if isinstance(c, types.MethodType):
515 515 offset = -1
516 516 else:
517 517 offset = 0
518 518 if nargs + offset == 0:
519 519 c()
520 520 elif nargs + offset == 1:
521 521 c(name)
522 522 elif nargs + offset == 2:
523 523 c(name, new_value)
524 524 elif nargs + offset == 3:
525 525 c(name, old_value, new_value)
526 526 else:
527 527 raise TraitError('a trait changed callback '
528 528 'must have 0-3 arguments.')
529 529 else:
530 530 raise TraitError('a trait changed callback '
531 531 'must be callable.')
532 532
533 533
534 534 def _add_notifiers(self, handler, name):
535 535 if name not in self._trait_notifiers:
536 536 nlist = []
537 537 self._trait_notifiers[name] = nlist
538 538 else:
539 539 nlist = self._trait_notifiers[name]
540 540 if handler not in nlist:
541 541 nlist.append(handler)
542 542
543 543 def _remove_notifiers(self, handler, name):
544 544 if name in self._trait_notifiers:
545 545 nlist = self._trait_notifiers[name]
546 546 try:
547 547 index = nlist.index(handler)
548 548 except ValueError:
549 549 pass
550 550 else:
551 551 del nlist[index]
552 552
553 553 def on_trait_change(self, handler, name=None, remove=False):
554 554 """Setup a handler to be called when a trait changes.
555 555
556 556 This is used to setup dynamic notifications of trait changes.
557 557
558 558 Static handlers can be created by creating methods on a HasTraits
559 559 subclass with the naming convention '_[traitname]_changed'. Thus,
560 560 to create static handler for the trait 'a', create the method
561 561 _a_changed(self, name, old, new) (fewer arguments can be used, see
562 562 below).
563 563
564 564 Parameters
565 565 ----------
566 566 handler : callable
567 567 A callable that is called when a trait changes. Its
568 568 signature can be handler(), handler(name), handler(name, new)
569 569 or handler(name, old, new).
570 570 name : list, str, None
571 571 If None, the handler will apply to all traits. If a list
572 572 of str, handler will apply to all names in the list. If a
573 573 str, the handler will apply just to that name.
574 574 remove : bool
575 575 If False (the default), then install the handler. If True
576 576 then unintall it.
577 577 """
578 578 if remove:
579 579 names = parse_notifier_name(name)
580 580 for n in names:
581 581 self._remove_notifiers(handler, n)
582 582 else:
583 583 names = parse_notifier_name(name)
584 584 for n in names:
585 585 self._add_notifiers(handler, n)
586 586
587 587 @classmethod
588 588 def class_trait_names(cls, **metadata):
589 589 """Get a list of all the names of this class' traits.
590 590
591 591 This method is just like the :meth:`trait_names` method,
592 592 but is unbound.
593 593 """
594 594 return cls.class_traits(**metadata).keys()
595 595
596 596 @classmethod
597 597 def class_traits(cls, **metadata):
598 598 """Get a `dict` of all the traits of this class. The dictionary
599 599 is keyed on the name and the values are the TraitType objects.
600 600
601 601 This method is just like the :meth:`traits` method, but is unbound.
602 602
603 603 The TraitTypes returned don't know anything about the values
604 604 that the various HasTrait's instances are holding.
605 605
606 606 The metadata kwargs allow functions to be passed in which
607 607 filter traits based on metadata values. The functions should
608 608 take a single value as an argument and return a boolean. If
609 609 any function returns False, then the trait is not included in
610 610 the output. This does not allow for any simple way of
611 611 testing that a metadata name exists and has any
612 612 value because get_metadata returns None if a metadata key
613 613 doesn't exist.
614 614 """
615 615 traits = dict([memb for memb in getmembers(cls) if
616 616 isinstance(memb[1], TraitType)])
617 617
618 618 if len(metadata) == 0:
619 619 return traits
620 620
621 621 for meta_name, meta_eval in metadata.items():
622 622 if type(meta_eval) is not FunctionType:
623 623 metadata[meta_name] = _SimpleTest(meta_eval)
624 624
625 625 result = {}
626 626 for name, trait in traits.items():
627 627 for meta_name, meta_eval in metadata.items():
628 628 if not meta_eval(trait.get_metadata(meta_name)):
629 629 break
630 630 else:
631 631 result[name] = trait
632 632
633 633 return result
634 634
635 635 def trait_names(self, **metadata):
636 636 """Get a list of all the names of this class' traits."""
637 637 return self.traits(**metadata).keys()
638 638
639 639 def traits(self, **metadata):
640 640 """Get a `dict` of all the traits of this class. The dictionary
641 641 is keyed on the name and the values are the TraitType objects.
642 642
643 643 The TraitTypes returned don't know anything about the values
644 644 that the various HasTrait's instances are holding.
645 645
646 646 The metadata kwargs allow functions to be passed in which
647 647 filter traits based on metadata values. The functions should
648 648 take a single value as an argument and return a boolean. If
649 649 any function returns False, then the trait is not included in
650 650 the output. This does not allow for any simple way of
651 651 testing that a metadata name exists and has any
652 652 value because get_metadata returns None if a metadata key
653 653 doesn't exist.
654 654 """
655 655 traits = dict([memb for memb in getmembers(self.__class__) if
656 656 isinstance(memb[1], TraitType)])
657 657
658 658 if len(metadata) == 0:
659 659 return traits
660 660
661 661 for meta_name, meta_eval in metadata.items():
662 662 if type(meta_eval) is not FunctionType:
663 663 metadata[meta_name] = _SimpleTest(meta_eval)
664 664
665 665 result = {}
666 666 for name, trait in traits.items():
667 667 for meta_name, meta_eval in metadata.items():
668 668 if not meta_eval(trait.get_metadata(meta_name)):
669 669 break
670 670 else:
671 671 result[name] = trait
672 672
673 673 return result
674 674
675 675 def trait_metadata(self, traitname, key):
676 676 """Get metadata values for trait by key."""
677 677 try:
678 678 trait = getattr(self.__class__, traitname)
679 679 except AttributeError:
680 680 raise TraitError("Class %s does not have a trait named %s" %
681 681 (self.__class__.__name__, traitname))
682 682 else:
683 683 return trait.get_metadata(key)
684 684
685 685 #-----------------------------------------------------------------------------
686 686 # Actual TraitTypes implementations/subclasses
687 687 #-----------------------------------------------------------------------------
688 688
689 689 #-----------------------------------------------------------------------------
690 690 # TraitTypes subclasses for handling classes and instances of classes
691 691 #-----------------------------------------------------------------------------
692 692
693 693
694 694 class ClassBasedTraitType(TraitType):
695 695 """A trait with error reporting for Type, Instance and This."""
696 696
697 697 def error(self, obj, value):
698 698 kind = type(value)
699 699 if (not py3compat.PY3) and kind is InstanceType:
700 700 msg = 'class %s' % value.__class__.__name__
701 701 else:
702 702 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
703 703
704 704 if obj is not None:
705 705 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
706 706 % (self.name, class_of(obj),
707 707 self.info(), msg)
708 708 else:
709 709 e = "The '%s' trait must be %s, but a value of %r was specified." \
710 710 % (self.name, self.info(), msg)
711 711
712 712 raise TraitError(e)
713 713
714 714
715 715 class Type(ClassBasedTraitType):
716 716 """A trait whose value must be a subclass of a specified class."""
717 717
718 718 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
719 719 """Construct a Type trait
720 720
721 721 A Type trait specifies that its values must be subclasses of
722 722 a particular class.
723 723
724 724 If only ``default_value`` is given, it is used for the ``klass`` as
725 725 well.
726 726
727 727 Parameters
728 728 ----------
729 729 default_value : class, str or None
730 730 The default value must be a subclass of klass. If an str,
731 731 the str must be a fully specified class name, like 'foo.bar.Bah'.
732 732 The string is resolved into real class, when the parent
733 733 :class:`HasTraits` class is instantiated.
734 734 klass : class, str, None
735 735 Values of this trait must be a subclass of klass. The klass
736 736 may be specified in a string like: 'foo.bar.MyClass'.
737 737 The string is resolved into real class, when the parent
738 738 :class:`HasTraits` class is instantiated.
739 739 allow_none : boolean
740 740 Indicates whether None is allowed as an assignable value. Even if
741 741 ``False``, the default value may be ``None``.
742 742 """
743 743 if default_value is None:
744 744 if klass is None:
745 745 klass = object
746 746 elif klass is None:
747 747 klass = default_value
748 748
749 749 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
750 750 raise TraitError("A Type trait must specify a class.")
751 751
752 752 self.klass = klass
753 753
754 754 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
755 755
756 756 def validate(self, obj, value):
757 757 """Validates that the value is a valid object instance."""
758 758 if isinstance(value, py3compat.string_types):
759 759 try:
760 760 value = import_item(value)
761 761 except ImportError:
762 762 raise TraitError("The '%s' trait of %s instance must be a type, but "
763 763 "%r could not be imported" % (self.name, obj, value))
764 764 try:
765 765 if issubclass(value, self.klass):
766 766 return value
767 767 except:
768 768 pass
769 769
770 770 self.error(obj, value)
771 771
772 772 def info(self):
773 773 """ Returns a description of the trait."""
774 774 if isinstance(self.klass, py3compat.string_types):
775 775 klass = self.klass
776 776 else:
777 777 klass = self.klass.__name__
778 778 result = 'a subclass of ' + klass
779 779 if self.allow_none:
780 780 return result + ' or None'
781 781 return result
782 782
783 783 def instance_init(self, obj):
784 784 self._resolve_classes()
785 785 super(Type, self).instance_init(obj)
786 786
787 787 def _resolve_classes(self):
788 788 if isinstance(self.klass, py3compat.string_types):
789 789 self.klass = import_item(self.klass)
790 790 if isinstance(self.default_value, py3compat.string_types):
791 791 self.default_value = import_item(self.default_value)
792 792
793 793 def get_default_value(self):
794 794 return self.default_value
795 795
796 796
797 797 class DefaultValueGenerator(object):
798 798 """A class for generating new default value instances."""
799 799
800 800 def __init__(self, *args, **kw):
801 801 self.args = args
802 802 self.kw = kw
803 803
804 804 def generate(self, klass):
805 805 return klass(*self.args, **self.kw)
806 806
807 807
808 808 class Instance(ClassBasedTraitType):
809 809 """A trait whose value must be an instance of a specified class.
810 810
811 811 The value can also be an instance of a subclass of the specified class.
812
813 Subclasses can declare default classes by overriding the klass attribute
812 814 """
813 815
816 klass = None
817
814 818 def __init__(self, klass=None, args=None, kw=None,
815 819 allow_none=True, **metadata ):
816 820 """Construct an Instance trait.
817 821
818 822 This trait allows values that are instances of a particular
819 823 class or its sublclasses. Our implementation is quite different
820 824 from that of enthough.traits as we don't allow instances to be used
821 825 for klass and we handle the ``args`` and ``kw`` arguments differently.
822 826
823 827 Parameters
824 828 ----------
825 829 klass : class, str
826 830 The class that forms the basis for the trait. Class names
827 831 can also be specified as strings, like 'foo.bar.Bar'.
828 832 args : tuple
829 833 Positional arguments for generating the default value.
830 834 kw : dict
831 835 Keyword arguments for generating the default value.
832 836 allow_none : bool
833 837 Indicates whether None is allowed as a value.
834 838
835 839 Notes
836 840 -----
837 841 If both ``args`` and ``kw`` are None, then the default value is None.
838 842 If ``args`` is a tuple and ``kw`` is a dict, then the default is
839 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
840 not (but not both), None is replace by ``()`` or ``{}``.
843 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
844 None, the None is replaced by ``()`` or ``{}``, respectively.
841 845 """
842
843 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types))):
844 raise TraitError('The klass argument must be a class'
845 ' you gave: %r' % klass)
846 self.klass = klass
846 if klass is None:
847 klass = self.klass
848
849 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
850 self.klass = klass
851 else:
852 raise TraitError('The klass attribute must be a class'
853 ' not: %r' % klass)
847 854
848 855 # self.klass is a class, so handle default_value
849 856 if args is None and kw is None:
850 857 default_value = None
851 858 else:
852 859 if args is None:
853 860 # kw is not None
854 861 args = ()
855 862 elif kw is None:
856 863 # args is not None
857 864 kw = {}
858 865
859 866 if not isinstance(kw, dict):
860 867 raise TraitError("The 'kw' argument must be a dict or None.")
861 868 if not isinstance(args, tuple):
862 869 raise TraitError("The 'args' argument must be a tuple or None.")
863 870
864 871 default_value = DefaultValueGenerator(*args, **kw)
865 872
866 873 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
867 874
868 875 def validate(self, obj, value):
869 876 if isinstance(value, self.klass):
870 877 return value
871 878 else:
872 879 self.error(obj, value)
873 880
874 881 def info(self):
875 882 if isinstance(self.klass, py3compat.string_types):
876 883 klass = self.klass
877 884 else:
878 885 klass = self.klass.__name__
879 886 result = class_of(klass)
880 887 if self.allow_none:
881 888 return result + ' or None'
882 889
883 890 return result
884 891
885 892 def instance_init(self, obj):
886 893 self._resolve_classes()
887 894 super(Instance, self).instance_init(obj)
888 895
889 896 def _resolve_classes(self):
890 897 if isinstance(self.klass, py3compat.string_types):
891 898 self.klass = import_item(self.klass)
892 899
893 900 def get_default_value(self):
894 901 """Instantiate a default value instance.
895 902
896 903 This is called when the containing HasTraits classes'
897 904 :meth:`__new__` method is called to ensure that a unique instance
898 905 is created for each HasTraits instance.
899 906 """
900 907 dv = self.default_value
901 908 if isinstance(dv, DefaultValueGenerator):
902 909 return dv.generate(self.klass)
903 910 else:
904 911 return dv
905 912
906 913
907 914 class This(ClassBasedTraitType):
908 915 """A trait for instances of the class containing this trait.
909 916
910 917 Because how how and when class bodies are executed, the ``This``
911 918 trait can only have a default value of None. This, and because we
912 919 always validate default values, ``allow_none`` is *always* true.
913 920 """
914 921
915 922 info_text = 'an instance of the same type as the receiver or None'
916 923
917 924 def __init__(self, **metadata):
918 925 super(This, self).__init__(None, **metadata)
919 926
920 927 def validate(self, obj, value):
921 928 # What if value is a superclass of obj.__class__? This is
922 929 # complicated if it was the superclass that defined the This
923 930 # trait.
924 931 if isinstance(value, self.this_class) or (value is None):
925 932 return value
926 933 else:
927 934 self.error(obj, value)
928 935
929 936
930 937 #-----------------------------------------------------------------------------
931 938 # Basic TraitTypes implementations/subclasses
932 939 #-----------------------------------------------------------------------------
933 940
934 941
935 942 class Any(TraitType):
936 943 default_value = None
937 944 info_text = 'any value'
938 945
939 946
940 947 class Int(TraitType):
941 948 """An int trait."""
942 949
943 950 default_value = 0
944 951 info_text = 'an int'
945 952
946 953 def validate(self, obj, value):
947 954 if isinstance(value, int):
948 955 return value
949 956 self.error(obj, value)
950 957
951 958 class CInt(Int):
952 959 """A casting version of the int trait."""
953 960
954 961 def validate(self, obj, value):
955 962 try:
956 963 return int(value)
957 964 except:
958 965 self.error(obj, value)
959 966
960 967 if py3compat.PY3:
961 968 Long, CLong = Int, CInt
962 969 Integer = Int
963 970 else:
964 971 class Long(TraitType):
965 972 """A long integer trait."""
966 973
967 974 default_value = 0
968 975 info_text = 'a long'
969 976
970 977 def validate(self, obj, value):
971 978 if isinstance(value, long):
972 979 return value
973 980 if isinstance(value, int):
974 981 return long(value)
975 982 self.error(obj, value)
976 983
977 984
978 985 class CLong(Long):
979 986 """A casting version of the long integer trait."""
980 987
981 988 def validate(self, obj, value):
982 989 try:
983 990 return long(value)
984 991 except:
985 992 self.error(obj, value)
986 993
987 994 class Integer(TraitType):
988 995 """An integer trait.
989 996
990 997 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
991 998
992 999 default_value = 0
993 1000 info_text = 'an integer'
994 1001
995 1002 def validate(self, obj, value):
996 1003 if isinstance(value, int):
997 1004 return value
998 1005 if isinstance(value, long):
999 1006 # downcast longs that fit in int:
1000 1007 # note that int(n > sys.maxint) returns a long, so
1001 1008 # we don't need a condition on this cast
1002 1009 return int(value)
1003 1010 if sys.platform == "cli":
1004 1011 from System import Int64
1005 1012 if isinstance(value, Int64):
1006 1013 return int(value)
1007 1014 self.error(obj, value)
1008 1015
1009 1016
1010 1017 class Float(TraitType):
1011 1018 """A float trait."""
1012 1019
1013 1020 default_value = 0.0
1014 1021 info_text = 'a float'
1015 1022
1016 1023 def validate(self, obj, value):
1017 1024 if isinstance(value, float):
1018 1025 return value
1019 1026 if isinstance(value, int):
1020 1027 return float(value)
1021 1028 self.error(obj, value)
1022 1029
1023 1030
1024 1031 class CFloat(Float):
1025 1032 """A casting version of the float trait."""
1026 1033
1027 1034 def validate(self, obj, value):
1028 1035 try:
1029 1036 return float(value)
1030 1037 except:
1031 1038 self.error(obj, value)
1032 1039
1033 1040 class Complex(TraitType):
1034 1041 """A trait for complex numbers."""
1035 1042
1036 1043 default_value = 0.0 + 0.0j
1037 1044 info_text = 'a complex number'
1038 1045
1039 1046 def validate(self, obj, value):
1040 1047 if isinstance(value, complex):
1041 1048 return value
1042 1049 if isinstance(value, (float, int)):
1043 1050 return complex(value)
1044 1051 self.error(obj, value)
1045 1052
1046 1053
1047 1054 class CComplex(Complex):
1048 1055 """A casting version of the complex number trait."""
1049 1056
1050 1057 def validate (self, obj, value):
1051 1058 try:
1052 1059 return complex(value)
1053 1060 except:
1054 1061 self.error(obj, value)
1055 1062
1056 1063 # We should always be explicit about whether we're using bytes or unicode, both
1057 1064 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1058 1065 # we don't have a Str type.
1059 1066 class Bytes(TraitType):
1060 1067 """A trait for byte strings."""
1061 1068
1062 1069 default_value = b''
1063 1070 info_text = 'a bytes object'
1064 1071
1065 1072 def validate(self, obj, value):
1066 1073 if isinstance(value, bytes):
1067 1074 return value
1068 1075 self.error(obj, value)
1069 1076
1070 1077
1071 1078 class CBytes(Bytes):
1072 1079 """A casting version of the byte string trait."""
1073 1080
1074 1081 def validate(self, obj, value):
1075 1082 try:
1076 1083 return bytes(value)
1077 1084 except:
1078 1085 self.error(obj, value)
1079 1086
1080 1087
1081 1088 class Unicode(TraitType):
1082 1089 """A trait for unicode strings."""
1083 1090
1084 1091 default_value = u''
1085 1092 info_text = 'a unicode string'
1086 1093
1087 1094 def validate(self, obj, value):
1088 1095 if isinstance(value, py3compat.unicode_type):
1089 1096 return value
1090 1097 if isinstance(value, bytes):
1091 1098 try:
1092 1099 return value.decode('ascii', 'strict')
1093 1100 except UnicodeDecodeError:
1094 1101 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1095 1102 raise TraitError(msg.format(value, self.name, class_of(obj)))
1096 1103 self.error(obj, value)
1097 1104
1098 1105
1099 1106 class CUnicode(Unicode):
1100 1107 """A casting version of the unicode trait."""
1101 1108
1102 1109 def validate(self, obj, value):
1103 1110 try:
1104 1111 return py3compat.unicode_type(value)
1105 1112 except:
1106 1113 self.error(obj, value)
1107 1114
1108 1115
1109 1116 class ObjectName(TraitType):
1110 1117 """A string holding a valid object name in this version of Python.
1111 1118
1112 1119 This does not check that the name exists in any scope."""
1113 1120 info_text = "a valid object identifier in Python"
1114 1121
1115 1122 if py3compat.PY3:
1116 1123 # Python 3:
1117 1124 coerce_str = staticmethod(lambda _,s: s)
1118 1125
1119 1126 else:
1120 1127 # Python 2:
1121 1128 def coerce_str(self, obj, value):
1122 1129 "In Python 2, coerce ascii-only unicode to str"
1123 1130 if isinstance(value, unicode):
1124 1131 try:
1125 1132 return str(value)
1126 1133 except UnicodeEncodeError:
1127 1134 self.error(obj, value)
1128 1135 return value
1129 1136
1130 1137 def validate(self, obj, value):
1131 1138 value = self.coerce_str(obj, value)
1132 1139
1133 1140 if isinstance(value, str) and py3compat.isidentifier(value):
1134 1141 return value
1135 1142 self.error(obj, value)
1136 1143
1137 1144 class DottedObjectName(ObjectName):
1138 1145 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1139 1146 def validate(self, obj, value):
1140 1147 value = self.coerce_str(obj, value)
1141 1148
1142 1149 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1143 1150 return value
1144 1151 self.error(obj, value)
1145 1152
1146 1153
1147 1154 class Bool(TraitType):
1148 1155 """A boolean (True, False) trait."""
1149 1156
1150 1157 default_value = False
1151 1158 info_text = 'a boolean'
1152 1159
1153 1160 def validate(self, obj, value):
1154 1161 if isinstance(value, bool):
1155 1162 return value
1156 1163 self.error(obj, value)
1157 1164
1158 1165
1159 1166 class CBool(Bool):
1160 1167 """A casting version of the boolean trait."""
1161 1168
1162 1169 def validate(self, obj, value):
1163 1170 try:
1164 1171 return bool(value)
1165 1172 except:
1166 1173 self.error(obj, value)
1167 1174
1168 1175
1169 1176 class Enum(TraitType):
1170 1177 """An enum that whose value must be in a given sequence."""
1171 1178
1172 1179 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1173 1180 self.values = values
1174 1181 super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata)
1175 1182
1176 1183 def validate(self, obj, value):
1177 1184 if value in self.values:
1178 1185 return value
1179 1186 self.error(obj, value)
1180 1187
1181 1188 def info(self):
1182 1189 """ Returns a description of the trait."""
1183 1190 result = 'any of ' + repr(self.values)
1184 1191 if self.allow_none:
1185 1192 return result + ' or None'
1186 1193 return result
1187 1194
1188 1195 class CaselessStrEnum(Enum):
1189 1196 """An enum of strings that are caseless in validate."""
1190 1197
1191 1198 def validate(self, obj, value):
1192 1199 if not isinstance(value, py3compat.string_types):
1193 1200 self.error(obj, value)
1194 1201
1195 1202 for v in self.values:
1196 1203 if v.lower() == value.lower():
1197 1204 return v
1198 1205 self.error(obj, value)
1199 1206
1200 1207 class Container(Instance):
1201 1208 """An instance of a container (list, set, etc.)
1202 1209
1203 1210 To be subclassed by overriding klass.
1204 1211 """
1205 1212 klass = None
1206 1213 _cast_types = ()
1207 1214 _valid_defaults = SequenceTypes
1208 1215 _trait = None
1209 1216
1210 1217 def __init__(self, trait=None, default_value=None, allow_none=True,
1211 1218 **metadata):
1212 1219 """Create a container trait type from a list, set, or tuple.
1213 1220
1214 1221 The default value is created by doing ``List(default_value)``,
1215 1222 which creates a copy of the ``default_value``.
1216 1223
1217 1224 ``trait`` can be specified, which restricts the type of elements
1218 1225 in the container to that TraitType.
1219 1226
1220 1227 If only one arg is given and it is not a Trait, it is taken as
1221 1228 ``default_value``:
1222 1229
1223 1230 ``c = List([1,2,3])``
1224 1231
1225 1232 Parameters
1226 1233 ----------
1227 1234
1228 1235 trait : TraitType [ optional ]
1229 1236 the type for restricting the contents of the Container. If unspecified,
1230 1237 types are not checked.
1231 1238
1232 1239 default_value : SequenceType [ optional ]
1233 1240 The default value for the Trait. Must be list/tuple/set, and
1234 1241 will be cast to the container type.
1235 1242
1236 1243 allow_none : Bool [ default True ]
1237 1244 Whether to allow the value to be None
1238 1245
1239 1246 **metadata : any
1240 1247 further keys for extensions to the Trait (e.g. config)
1241 1248
1242 1249 """
1243 1250 # allow List([values]):
1244 1251 if default_value is None and not is_trait(trait):
1245 1252 default_value = trait
1246 1253 trait = None
1247 1254
1248 1255 if default_value is None:
1249 1256 args = ()
1250 1257 elif isinstance(default_value, self._valid_defaults):
1251 1258 args = (default_value,)
1252 1259 else:
1253 1260 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1254 1261
1255 1262 if is_trait(trait):
1256 1263 self._trait = trait() if isinstance(trait, type) else trait
1257 1264 self._trait.name = 'element'
1258 1265 elif trait is not None:
1259 1266 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1260 1267
1261 1268 super(Container,self).__init__(klass=self.klass, args=args,
1262 1269 allow_none=allow_none, **metadata)
1263 1270
1264 1271 def element_error(self, obj, element, validator):
1265 1272 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1266 1273 % (self.name, class_of(obj), validator.info(), repr_type(element))
1267 1274 raise TraitError(e)
1268 1275
1269 1276 def validate(self, obj, value):
1270 1277 if isinstance(value, self._cast_types):
1271 1278 value = self.klass(value)
1272 1279 value = super(Container, self).validate(obj, value)
1273 1280 if value is None:
1274 1281 return value
1275 1282
1276 1283 value = self.validate_elements(obj, value)
1277 1284
1278 1285 return value
1279 1286
1280 1287 def validate_elements(self, obj, value):
1281 1288 validated = []
1282 1289 if self._trait is None or isinstance(self._trait, Any):
1283 1290 return value
1284 1291 for v in value:
1285 1292 try:
1286 1293 v = self._trait._validate(obj, v)
1287 1294 except TraitError:
1288 1295 self.element_error(obj, v, self._trait)
1289 1296 else:
1290 1297 validated.append(v)
1291 1298 return self.klass(validated)
1292 1299
1293 1300 def instance_init(self, obj):
1294 1301 if isinstance(self._trait, Instance):
1295 1302 self._trait._resolve_classes()
1296 1303 super(Container, self).instance_init(obj)
1297 1304
1298 1305
1299 1306 class List(Container):
1300 1307 """An instance of a Python list."""
1301 1308 klass = list
1302 1309 _cast_types = (tuple,)
1303 1310
1304 1311 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize,
1305 1312 allow_none=True, **metadata):
1306 1313 """Create a List trait type from a list, set, or tuple.
1307 1314
1308 1315 The default value is created by doing ``List(default_value)``,
1309 1316 which creates a copy of the ``default_value``.
1310 1317
1311 1318 ``trait`` can be specified, which restricts the type of elements
1312 1319 in the container to that TraitType.
1313 1320
1314 1321 If only one arg is given and it is not a Trait, it is taken as
1315 1322 ``default_value``:
1316 1323
1317 1324 ``c = List([1,2,3])``
1318 1325
1319 1326 Parameters
1320 1327 ----------
1321 1328
1322 1329 trait : TraitType [ optional ]
1323 1330 the type for restricting the contents of the Container. If unspecified,
1324 1331 types are not checked.
1325 1332
1326 1333 default_value : SequenceType [ optional ]
1327 1334 The default value for the Trait. Must be list/tuple/set, and
1328 1335 will be cast to the container type.
1329 1336
1330 1337 minlen : Int [ default 0 ]
1331 1338 The minimum length of the input list
1332 1339
1333 1340 maxlen : Int [ default sys.maxsize ]
1334 1341 The maximum length of the input list
1335 1342
1336 1343 allow_none : Bool [ default True ]
1337 1344 Whether to allow the value to be None
1338 1345
1339 1346 **metadata : any
1340 1347 further keys for extensions to the Trait (e.g. config)
1341 1348
1342 1349 """
1343 1350 self._minlen = minlen
1344 1351 self._maxlen = maxlen
1345 1352 super(List, self).__init__(trait=trait, default_value=default_value,
1346 1353 allow_none=allow_none, **metadata)
1347 1354
1348 1355 def length_error(self, obj, value):
1349 1356 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1350 1357 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1351 1358 raise TraitError(e)
1352 1359
1353 1360 def validate_elements(self, obj, value):
1354 1361 length = len(value)
1355 1362 if length < self._minlen or length > self._maxlen:
1356 1363 self.length_error(obj, value)
1357 1364
1358 1365 return super(List, self).validate_elements(obj, value)
1359 1366
1360 1367 def validate(self, obj, value):
1361 1368 value = super(List, self).validate(obj, value)
1362 1369
1363 1370 value = self.validate_elements(obj, value)
1364 1371
1365 1372 return value
1366 1373
1367 1374
1368 1375
1369 1376 class Set(List):
1370 1377 """An instance of a Python set."""
1371 1378 klass = set
1372 1379 _cast_types = (tuple, list)
1373 1380
1374 1381 class Tuple(Container):
1375 1382 """An instance of a Python tuple."""
1376 1383 klass = tuple
1377 1384 _cast_types = (list,)
1378 1385
1379 1386 def __init__(self, *traits, **metadata):
1380 1387 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1381 1388
1382 1389 Create a tuple from a list, set, or tuple.
1383 1390
1384 1391 Create a fixed-type tuple with Traits:
1385 1392
1386 1393 ``t = Tuple(Int, Str, CStr)``
1387 1394
1388 1395 would be length 3, with Int,Str,CStr for each element.
1389 1396
1390 1397 If only one arg is given and it is not a Trait, it is taken as
1391 1398 default_value:
1392 1399
1393 1400 ``t = Tuple((1,2,3))``
1394 1401
1395 1402 Otherwise, ``default_value`` *must* be specified by keyword.
1396 1403
1397 1404 Parameters
1398 1405 ----------
1399 1406
1400 1407 *traits : TraitTypes [ optional ]
1401 1408 the tsype for restricting the contents of the Tuple. If unspecified,
1402 1409 types are not checked. If specified, then each positional argument
1403 1410 corresponds to an element of the tuple. Tuples defined with traits
1404 1411 are of fixed length.
1405 1412
1406 1413 default_value : SequenceType [ optional ]
1407 1414 The default value for the Tuple. Must be list/tuple/set, and
1408 1415 will be cast to a tuple. If `traits` are specified, the
1409 1416 `default_value` must conform to the shape and type they specify.
1410 1417
1411 1418 allow_none : Bool [ default True ]
1412 1419 Whether to allow the value to be None
1413 1420
1414 1421 **metadata : any
1415 1422 further keys for extensions to the Trait (e.g. config)
1416 1423
1417 1424 """
1418 1425 default_value = metadata.pop('default_value', None)
1419 1426 allow_none = metadata.pop('allow_none', True)
1420 1427
1421 1428 # allow Tuple((values,)):
1422 1429 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1423 1430 default_value = traits[0]
1424 1431 traits = ()
1425 1432
1426 1433 if default_value is None:
1427 1434 args = ()
1428 1435 elif isinstance(default_value, self._valid_defaults):
1429 1436 args = (default_value,)
1430 1437 else:
1431 1438 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1432 1439
1433 1440 self._traits = []
1434 1441 for trait in traits:
1435 1442 t = trait() if isinstance(trait, type) else trait
1436 1443 t.name = 'element'
1437 1444 self._traits.append(t)
1438 1445
1439 1446 if self._traits and default_value is None:
1440 1447 # don't allow default to be an empty container if length is specified
1441 1448 args = None
1442 1449 super(Container,self).__init__(klass=self.klass, args=args,
1443 1450 allow_none=allow_none, **metadata)
1444 1451
1445 1452 def validate_elements(self, obj, value):
1446 1453 if not self._traits:
1447 1454 # nothing to validate
1448 1455 return value
1449 1456 if len(value) != len(self._traits):
1450 1457 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1451 1458 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1452 1459 raise TraitError(e)
1453 1460
1454 1461 validated = []
1455 1462 for t,v in zip(self._traits, value):
1456 1463 try:
1457 1464 v = t._validate(obj, v)
1458 1465 except TraitError:
1459 1466 self.element_error(obj, v, t)
1460 1467 else:
1461 1468 validated.append(v)
1462 1469 return tuple(validated)
1463 1470
1464 1471
1465 1472 class Dict(Instance):
1466 1473 """An instance of a Python dict."""
1467 1474
1468 1475 def __init__(self, default_value=None, allow_none=True, **metadata):
1469 1476 """Create a dict trait type from a dict.
1470 1477
1471 1478 The default value is created by doing ``dict(default_value)``,
1472 1479 which creates a copy of the ``default_value``.
1473 1480 """
1474 1481 if default_value is None:
1475 1482 args = ((),)
1476 1483 elif isinstance(default_value, dict):
1477 1484 args = (default_value,)
1478 1485 elif isinstance(default_value, SequenceTypes):
1479 1486 args = (default_value,)
1480 1487 else:
1481 1488 raise TypeError('default value of Dict was %s' % default_value)
1482 1489
1483 1490 super(Dict,self).__init__(klass=dict, args=args,
1484 1491 allow_none=allow_none, **metadata)
1485 1492
1486 1493 class TCPAddress(TraitType):
1487 1494 """A trait for an (ip, port) tuple.
1488 1495
1489 1496 This allows for both IPv4 IP addresses as well as hostnames.
1490 1497 """
1491 1498
1492 1499 default_value = ('127.0.0.1', 0)
1493 1500 info_text = 'an (ip, port) tuple'
1494 1501
1495 1502 def validate(self, obj, value):
1496 1503 if isinstance(value, tuple):
1497 1504 if len(value) == 2:
1498 1505 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1499 1506 port = value[1]
1500 1507 if port >= 0 and port <= 65535:
1501 1508 return value
1502 1509 self.error(obj, value)
1503 1510
1504 1511 class CRegExp(TraitType):
1505 1512 """A casting compiled regular expression trait.
1506 1513
1507 1514 Accepts both strings and compiled regular expressions. The resulting
1508 1515 attribute will be a compiled regular expression."""
1509 1516
1510 1517 info_text = 'a regular expression'
1511 1518
1512 1519 def validate(self, obj, value):
1513 1520 try:
1514 1521 return re.compile(value)
1515 1522 except:
1516 1523 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now