##// END OF EJS Templates
Merge pull request #6128 from jasongrout/widget-trait-serialization...
Brian E. Granger -
r17330:3c3d1ae5 merge
parent child Browse files
Show More
@@ -1,482 +1,484
1 // Copyright (c) IPython Development Team.
1 // Copyright (c) IPython Development Team.
2 // Distributed under the terms of the Modified BSD License.
2 // Distributed under the terms of the Modified BSD License.
3
3
4 define(["widgets/js/manager",
4 define(["widgets/js/manager",
5 "underscore",
5 "underscore",
6 "backbone",
6 "backbone",
7 "jquery",
7 "jquery",
8 "base/js/namespace",
8 "base/js/namespace",
9 ], function(widgetmanager, _, Backbone, $, IPython){
9 ], function(widgetmanager, _, Backbone, $, IPython){
10
10
11 var WidgetModel = Backbone.Model.extend({
11 var WidgetModel = Backbone.Model.extend({
12 constructor: function (widget_manager, model_id, comm) {
12 constructor: function (widget_manager, model_id, comm) {
13 // Constructor
13 // Constructor
14 //
14 //
15 // Creates a WidgetModel instance.
15 // Creates a WidgetModel instance.
16 //
16 //
17 // Parameters
17 // Parameters
18 // ----------
18 // ----------
19 // widget_manager : WidgetManager instance
19 // widget_manager : WidgetManager instance
20 // model_id : string
20 // model_id : string
21 // An ID unique to this model.
21 // An ID unique to this model.
22 // comm : Comm instance (optional)
22 // comm : Comm instance (optional)
23 this.widget_manager = widget_manager;
23 this.widget_manager = widget_manager;
24 this._buffered_state_diff = {};
24 this._buffered_state_diff = {};
25 this.pending_msgs = 0;
25 this.pending_msgs = 0;
26 this.msg_buffer = null;
26 this.msg_buffer = null;
27 this.key_value_lock = null;
27 this.key_value_lock = null;
28 this.id = model_id;
28 this.id = model_id;
29 this.views = [];
29 this.views = [];
30
30
31 if (comm !== undefined) {
31 if (comm !== undefined) {
32 // Remember comm associated with the model.
32 // Remember comm associated with the model.
33 this.comm = comm;
33 this.comm = comm;
34 comm.model = this;
34 comm.model = this;
35
35
36 // Hook comm messages up to model.
36 // Hook comm messages up to model.
37 comm.on_close($.proxy(this._handle_comm_closed, this));
37 comm.on_close($.proxy(this._handle_comm_closed, this));
38 comm.on_msg($.proxy(this._handle_comm_msg, this));
38 comm.on_msg($.proxy(this._handle_comm_msg, this));
39 }
39 }
40 return Backbone.Model.apply(this);
40 return Backbone.Model.apply(this);
41 },
41 },
42
42
43 send: function (content, callbacks) {
43 send: function (content, callbacks) {
44 // Send a custom msg over the comm.
44 // Send a custom msg over the comm.
45 if (this.comm !== undefined) {
45 if (this.comm !== undefined) {
46 var data = {method: 'custom', content: content};
46 var data = {method: 'custom', content: content};
47 this.comm.send(data, callbacks);
47 this.comm.send(data, callbacks);
48 this.pending_msgs++;
48 this.pending_msgs++;
49 }
49 }
50 },
50 },
51
51
52 _handle_comm_closed: function (msg) {
52 _handle_comm_closed: function (msg) {
53 // Handle when a widget is closed.
53 // Handle when a widget is closed.
54 this.trigger('comm:close');
54 this.trigger('comm:close');
55 delete this.comm.model; // Delete ref so GC will collect widget model.
55 delete this.comm.model; // Delete ref so GC will collect widget model.
56 delete this.comm;
56 delete this.comm;
57 delete this.model_id; // Delete id from model so widget manager cleans up.
57 delete this.model_id; // Delete id from model so widget manager cleans up.
58 _.each(this.views, function(view, i) {
58 _.each(this.views, function(view, i) {
59 view.remove();
59 view.remove();
60 });
60 });
61 },
61 },
62
62
63 _handle_comm_msg: function (msg) {
63 _handle_comm_msg: function (msg) {
64 // Handle incoming comm msg.
64 // Handle incoming comm msg.
65 var method = msg.content.data.method;
65 var method = msg.content.data.method;
66 switch (method) {
66 switch (method) {
67 case 'update':
67 case 'update':
68 this.apply_update(msg.content.data.state);
68 this.apply_update(msg.content.data.state);
69 break;
69 break;
70 case 'custom':
70 case 'custom':
71 this.trigger('msg:custom', msg.content.data.content);
71 this.trigger('msg:custom', msg.content.data.content);
72 break;
72 break;
73 case 'display':
73 case 'display':
74 this.widget_manager.display_view(msg, this);
74 this.widget_manager.display_view(msg, this);
75 break;
75 break;
76 }
76 }
77 },
77 },
78
78
79 apply_update: function (state) {
79 apply_update: function (state) {
80 // Handle when a widget is updated via the python side.
80 // Handle when a widget is updated via the python side.
81 var that = this;
81 var that = this;
82 _.each(state, function(value, key) {
82 _.each(state, function(value, key) {
83 that.key_value_lock = [key, value];
83 that.key_value_lock = [key, value];
84 try {
84 try {
85 WidgetModel.__super__.set.apply(that, [key, that._unpack_models(value)]);
85 WidgetModel.__super__.set.apply(that, [key, that._unpack_models(value)]);
86 } finally {
86 } finally {
87 that.key_value_lock = null;
87 that.key_value_lock = null;
88 }
88 }
89 });
89 });
90 },
90 },
91
91
92 _handle_status: function (msg, callbacks) {
92 _handle_status: function (msg, callbacks) {
93 // Handle status msgs.
93 // Handle status msgs.
94
94
95 // execution_state : ('busy', 'idle', 'starting')
95 // execution_state : ('busy', 'idle', 'starting')
96 if (this.comm !== undefined) {
96 if (this.comm !== undefined) {
97 if (msg.content.execution_state ==='idle') {
97 if (msg.content.execution_state ==='idle') {
98 // Send buffer if this message caused another message to be
98 // Send buffer if this message caused another message to be
99 // throttled.
99 // throttled.
100 if (this.msg_buffer !== null &&
100 if (this.msg_buffer !== null &&
101 (this.get('msg_throttle') || 3) === this.pending_msgs) {
101 (this.get('msg_throttle') || 3) === this.pending_msgs) {
102 var data = {method: 'backbone', sync_method: 'update', sync_data: this.msg_buffer};
102 var data = {method: 'backbone', sync_method: 'update', sync_data: this.msg_buffer};
103 this.comm.send(data, callbacks);
103 this.comm.send(data, callbacks);
104 this.msg_buffer = null;
104 this.msg_buffer = null;
105 } else {
105 } else {
106 --this.pending_msgs;
106 --this.pending_msgs;
107 }
107 }
108 }
108 }
109 }
109 }
110 },
110 },
111
111
112 callbacks: function(view) {
112 callbacks: function(view) {
113 // Create msg callbacks for a comm msg.
113 // Create msg callbacks for a comm msg.
114 var callbacks = this.widget_manager.callbacks(view);
114 var callbacks = this.widget_manager.callbacks(view);
115
115
116 if (callbacks.iopub === undefined) {
116 if (callbacks.iopub === undefined) {
117 callbacks.iopub = {};
117 callbacks.iopub = {};
118 }
118 }
119
119
120 var that = this;
120 var that = this;
121 callbacks.iopub.status = function (msg) {
121 callbacks.iopub.status = function (msg) {
122 that._handle_status(msg, callbacks);
122 that._handle_status(msg, callbacks);
123 };
123 };
124 return callbacks;
124 return callbacks;
125 },
125 },
126
126
127 set: function(key, val, options) {
127 set: function(key, val, options) {
128 // Set a value.
128 // Set a value.
129 var return_value = WidgetModel.__super__.set.apply(this, arguments);
129 var return_value = WidgetModel.__super__.set.apply(this, arguments);
130
130
131 // Backbone only remembers the diff of the most recent set()
131 // Backbone only remembers the diff of the most recent set()
132 // operation. Calling set multiple times in a row results in a
132 // operation. Calling set multiple times in a row results in a
133 // loss of diff information. Here we keep our own running diff.
133 // loss of diff information. Here we keep our own running diff.
134 this._buffered_state_diff = $.extend(this._buffered_state_diff, this.changedAttributes() || {});
134 this._buffered_state_diff = $.extend(this._buffered_state_diff, this.changedAttributes() || {});
135 return return_value;
135 return return_value;
136 },
136 },
137
137
138 sync: function (method, model, options) {
138 sync: function (method, model, options) {
139 // Handle sync to the back-end. Called when a model.save() is called.
139 // Handle sync to the back-end. Called when a model.save() is called.
140
140
141 // Make sure a comm exists.
141 // Make sure a comm exists.
142 var error = options.error || function() {
142 var error = options.error || function() {
143 console.error('Backbone sync error:', arguments);
143 console.error('Backbone sync error:', arguments);
144 };
144 };
145 if (this.comm === undefined) {
145 if (this.comm === undefined) {
146 error();
146 error();
147 return false;
147 return false;
148 }
148 }
149
149
150 // Delete any key value pairs that the back-end already knows about.
150 // Delete any key value pairs that the back-end already knows about.
151 var attrs = (method === 'patch') ? options.attrs : model.toJSON(options);
151 var attrs = (method === 'patch') ? options.attrs : model.toJSON(options);
152 if (this.key_value_lock !== null) {
152 if (this.key_value_lock !== null) {
153 var key = this.key_value_lock[0];
153 var key = this.key_value_lock[0];
154 var value = this.key_value_lock[1];
154 var value = this.key_value_lock[1];
155 if (attrs[key] === value) {
155 if (attrs[key] === value) {
156 delete attrs[key];
156 delete attrs[key];
157 }
157 }
158 }
158 }
159
159
160 // Only sync if there are attributes to send to the back-end.
160 // Only sync if there are attributes to send to the back-end.
161 attrs = this._pack_models(attrs);
161 attrs = this._pack_models(attrs);
162 if (_.size(attrs) > 0) {
162 if (_.size(attrs) > 0) {
163
163
164 // If this message was sent via backbone itself, it will not
164 // If this message was sent via backbone itself, it will not
165 // have any callbacks. It's important that we create callbacks
165 // have any callbacks. It's important that we create callbacks
166 // so we can listen for status messages, etc...
166 // so we can listen for status messages, etc...
167 var callbacks = options.callbacks || this.callbacks();
167 var callbacks = options.callbacks || this.callbacks();
168
168
169 // Check throttle.
169 // Check throttle.
170 if (this.pending_msgs >= (this.get('msg_throttle') || 3)) {
170 if (this.pending_msgs >= (this.get('msg_throttle') || 3)) {
171 // The throttle has been exceeded, buffer the current msg so
171 // The throttle has been exceeded, buffer the current msg so
172 // it can be sent once the kernel has finished processing
172 // it can be sent once the kernel has finished processing
173 // some of the existing messages.
173 // some of the existing messages.
174
174
175 // Combine updates if it is a 'patch' sync, otherwise replace updates
175 // Combine updates if it is a 'patch' sync, otherwise replace updates
176 switch (method) {
176 switch (method) {
177 case 'patch':
177 case 'patch':
178 this.msg_buffer = $.extend(this.msg_buffer || {}, attrs);
178 this.msg_buffer = $.extend(this.msg_buffer || {}, attrs);
179 break;
179 break;
180 case 'update':
180 case 'update':
181 case 'create':
181 case 'create':
182 this.msg_buffer = attrs;
182 this.msg_buffer = attrs;
183 break;
183 break;
184 default:
184 default:
185 error();
185 error();
186 return false;
186 return false;
187 }
187 }
188 this.msg_buffer_callbacks = callbacks;
188 this.msg_buffer_callbacks = callbacks;
189
189
190 } else {
190 } else {
191 // We haven't exceeded the throttle, send the message like
191 // We haven't exceeded the throttle, send the message like
192 // normal.
192 // normal.
193 var data = {method: 'backbone', sync_data: attrs};
193 var data = {method: 'backbone', sync_data: attrs};
194 this.comm.send(data, callbacks);
194 this.comm.send(data, callbacks);
195 this.pending_msgs++;
195 this.pending_msgs++;
196 }
196 }
197 }
197 }
198 // Since the comm is a one-way communication, assume the message
198 // Since the comm is a one-way communication, assume the message
199 // arrived. Don't call success since we don't have a model back from the server
199 // arrived. Don't call success since we don't have a model back from the server
200 // this means we miss out on the 'sync' event.
200 // this means we miss out on the 'sync' event.
201 this._buffered_state_diff = {};
201 this._buffered_state_diff = {};
202 },
202 },
203
203
204 save_changes: function(callbacks) {
204 save_changes: function(callbacks) {
205 // Push this model's state to the back-end
205 // Push this model's state to the back-end
206 //
206 //
207 // This invokes a Backbone.Sync.
207 // This invokes a Backbone.Sync.
208 this.save(this._buffered_state_diff, {patch: true, callbacks: callbacks});
208 this.save(this._buffered_state_diff, {patch: true, callbacks: callbacks});
209 },
209 },
210
210
211 _pack_models: function(value) {
211 _pack_models: function(value) {
212 // Replace models with model ids recursively.
212 // Replace models with model ids recursively.
213 var that = this;
213 var that = this;
214 var packed;
214 var packed;
215 if (value instanceof Backbone.Model) {
215 if (value instanceof Backbone.Model) {
216 return value.id;
216 return "IPY_MODEL_" + value.id;
217
217
218 } else if ($.isArray(value)) {
218 } else if ($.isArray(value)) {
219 packed = [];
219 packed = [];
220 _.each(value, function(sub_value, key) {
220 _.each(value, function(sub_value, key) {
221 packed.push(that._pack_models(sub_value));
221 packed.push(that._pack_models(sub_value));
222 });
222 });
223 return packed;
223 return packed;
224
224
225 } else if (value instanceof Object) {
225 } else if (value instanceof Object) {
226 packed = {};
226 packed = {};
227 _.each(value, function(sub_value, key) {
227 _.each(value, function(sub_value, key) {
228 packed[key] = that._pack_models(sub_value);
228 packed[key] = that._pack_models(sub_value);
229 });
229 });
230 return packed;
230 return packed;
231
231
232 } else {
232 } else {
233 return value;
233 return value;
234 }
234 }
235 },
235 },
236
236
237 _unpack_models: function(value) {
237 _unpack_models: function(value) {
238 // Replace model ids with models recursively.
238 // Replace model ids with models recursively.
239 var that = this;
239 var that = this;
240 var unpacked;
240 var unpacked;
241 if ($.isArray(value)) {
241 if ($.isArray(value)) {
242 unpacked = [];
242 unpacked = [];
243 _.each(value, function(sub_value, key) {
243 _.each(value, function(sub_value, key) {
244 unpacked.push(that._unpack_models(sub_value));
244 unpacked.push(that._unpack_models(sub_value));
245 });
245 });
246 return unpacked;
246 return unpacked;
247
247
248 } else if (value instanceof Object) {
248 } else if (value instanceof Object) {
249 unpacked = {};
249 unpacked = {};
250 _.each(value, function(sub_value, key) {
250 _.each(value, function(sub_value, key) {
251 unpacked[key] = that._unpack_models(sub_value);
251 unpacked[key] = that._unpack_models(sub_value);
252 });
252 });
253 return unpacked;
253 return unpacked;
254
254
255 } else if (typeof value === 'string' && value.slice(0,10) === "IPY_MODEL_") {
256 var model = this.widget_manager.get_model(value.slice(10, value.length));
257 if (model) {
258 return model;
259 } else {
260 return value;
261 }
255 } else {
262 } else {
256 var model = this.widget_manager.get_model(value);
257 if (model) {
258 return model;
259 } else {
260 return value;
263 return value;
261 }
262 }
264 }
263 },
265 },
264
266
265 });
267 });
266 widgetmanager.WidgetManager.register_widget_model('WidgetModel', WidgetModel);
268 widgetmanager.WidgetManager.register_widget_model('WidgetModel', WidgetModel);
267
269
268
270
269 var WidgetView = Backbone.View.extend({
271 var WidgetView = Backbone.View.extend({
270 initialize: function(parameters) {
272 initialize: function(parameters) {
271 // Public constructor.
273 // Public constructor.
272 this.model.on('change',this.update,this);
274 this.model.on('change',this.update,this);
273 this.options = parameters.options;
275 this.options = parameters.options;
274 this.child_model_views = {};
276 this.child_model_views = {};
275 this.child_views = {};
277 this.child_views = {};
276 this.model.views.push(this);
278 this.model.views.push(this);
277 this.id = this.id || IPython.utils.uuid();
279 this.id = this.id || IPython.utils.uuid();
278 },
280 },
279
281
280 update: function(){
282 update: function(){
281 // Triggered on model change.
283 // Triggered on model change.
282 //
284 //
283 // Update view to be consistent with this.model
285 // Update view to be consistent with this.model
284 },
286 },
285
287
286 create_child_view: function(child_model, options) {
288 create_child_view: function(child_model, options) {
287 // Create and return a child view.
289 // Create and return a child view.
288 //
290 //
289 // -given a model and (optionally) a view name if the view name is
291 // -given a model and (optionally) a view name if the view name is
290 // not given, it defaults to the model's default view attribute.
292 // not given, it defaults to the model's default view attribute.
291
293
292 // TODO: this is hacky, and makes the view depend on this cell attribute and widget manager behavior
294 // TODO: this is hacky, and makes the view depend on this cell attribute and widget manager behavior
293 // it would be great to have the widget manager add the cell metadata
295 // it would be great to have the widget manager add the cell metadata
294 // to the subview without having to add it here.
296 // to the subview without having to add it here.
295 options = $.extend({ parent: this }, options || {});
297 options = $.extend({ parent: this }, options || {});
296 var child_view = this.model.widget_manager.create_view(child_model, options, this);
298 var child_view = this.model.widget_manager.create_view(child_model, options, this);
297
299
298 // Associate the view id with the model id.
300 // Associate the view id with the model id.
299 if (this.child_model_views[child_model.id] === undefined) {
301 if (this.child_model_views[child_model.id] === undefined) {
300 this.child_model_views[child_model.id] = [];
302 this.child_model_views[child_model.id] = [];
301 }
303 }
302 this.child_model_views[child_model.id].push(child_view.id);
304 this.child_model_views[child_model.id].push(child_view.id);
303
305
304 // Remember the view by id.
306 // Remember the view by id.
305 this.child_views[child_view.id] = child_view;
307 this.child_views[child_view.id] = child_view;
306 return child_view;
308 return child_view;
307 },
309 },
308
310
309 pop_child_view: function(child_model) {
311 pop_child_view: function(child_model) {
310 // Delete a child view that was previously created using create_child_view.
312 // Delete a child view that was previously created using create_child_view.
311 var view_ids = this.child_model_views[child_model.id];
313 var view_ids = this.child_model_views[child_model.id];
312 if (view_ids !== undefined) {
314 if (view_ids !== undefined) {
313
315
314 // Only delete the first view in the list.
316 // Only delete the first view in the list.
315 var view_id = view_ids[0];
317 var view_id = view_ids[0];
316 var view = this.child_views[view_id];
318 var view = this.child_views[view_id];
317 delete this.child_views[view_id];
319 delete this.child_views[view_id];
318 view_ids.splice(0,1);
320 view_ids.splice(0,1);
319 child_model.views.pop(view);
321 child_model.views.pop(view);
320
322
321 // Remove the view list specific to this model if it is empty.
323 // Remove the view list specific to this model if it is empty.
322 if (view_ids.length === 0) {
324 if (view_ids.length === 0) {
323 delete this.child_model_views[child_model.id];
325 delete this.child_model_views[child_model.id];
324 }
326 }
325 return view;
327 return view;
326 }
328 }
327 return null;
329 return null;
328 },
330 },
329
331
330 do_diff: function(old_list, new_list, removed_callback, added_callback) {
332 do_diff: function(old_list, new_list, removed_callback, added_callback) {
331 // Difference a changed list and call remove and add callbacks for
333 // Difference a changed list and call remove and add callbacks for
332 // each removed and added item in the new list.
334 // each removed and added item in the new list.
333 //
335 //
334 // Parameters
336 // Parameters
335 // ----------
337 // ----------
336 // old_list : array
338 // old_list : array
337 // new_list : array
339 // new_list : array
338 // removed_callback : Callback(item)
340 // removed_callback : Callback(item)
339 // Callback that is called for each item removed.
341 // Callback that is called for each item removed.
340 // added_callback : Callback(item)
342 // added_callback : Callback(item)
341 // Callback that is called for each item added.
343 // Callback that is called for each item added.
342
344
343 // Walk the lists until an unequal entry is found.
345 // Walk the lists until an unequal entry is found.
344 var i;
346 var i;
345 for (i = 0; i < new_list.length; i++) {
347 for (i = 0; i < new_list.length; i++) {
346 if (i < old_list.length || new_list[i] !== old_list[i]) {
348 if (i < old_list.length || new_list[i] !== old_list[i]) {
347 break;
349 break;
348 }
350 }
349 }
351 }
350
352
351 // Remove the non-matching items from the old list.
353 // Remove the non-matching items from the old list.
352 for (var j = i; j < old_list.length; j++) {
354 for (var j = i; j < old_list.length; j++) {
353 removed_callback(old_list[j]);
355 removed_callback(old_list[j]);
354 }
356 }
355
357
356 // Add the rest of the new list items.
358 // Add the rest of the new list items.
357 for (i; i < new_list.length; i++) {
359 for (i; i < new_list.length; i++) {
358 added_callback(new_list[i]);
360 added_callback(new_list[i]);
359 }
361 }
360 },
362 },
361
363
362 callbacks: function(){
364 callbacks: function(){
363 // Create msg callbacks for a comm msg.
365 // Create msg callbacks for a comm msg.
364 return this.model.callbacks(this);
366 return this.model.callbacks(this);
365 },
367 },
366
368
367 render: function(){
369 render: function(){
368 // Render the view.
370 // Render the view.
369 //
371 //
370 // By default, this is only called the first time the view is created
372 // By default, this is only called the first time the view is created
371 },
373 },
372
374
373 show: function(){
375 show: function(){
374 // Show the widget-area
376 // Show the widget-area
375 if (this.options && this.options.cell &&
377 if (this.options && this.options.cell &&
376 this.options.cell.widget_area !== undefined) {
378 this.options.cell.widget_area !== undefined) {
377 this.options.cell.widget_area.show();
379 this.options.cell.widget_area.show();
378 }
380 }
379 },
381 },
380
382
381 send: function (content) {
383 send: function (content) {
382 // Send a custom msg associated with this view.
384 // Send a custom msg associated with this view.
383 this.model.send(content, this.callbacks());
385 this.model.send(content, this.callbacks());
384 },
386 },
385
387
386 touch: function () {
388 touch: function () {
387 this.model.save_changes(this.callbacks());
389 this.model.save_changes(this.callbacks());
388 },
390 },
389 });
391 });
390
392
391
393
392 var DOMWidgetView = WidgetView.extend({
394 var DOMWidgetView = WidgetView.extend({
393 initialize: function (options) {
395 initialize: function (options) {
394 // Public constructor
396 // Public constructor
395
397
396 // In the future we may want to make changes more granular
398 // In the future we may want to make changes more granular
397 // (e.g., trigger on visible:change).
399 // (e.g., trigger on visible:change).
398 this.model.on('change', this.update, this);
400 this.model.on('change', this.update, this);
399 this.model.on('msg:custom', this.on_msg, this);
401 this.model.on('msg:custom', this.on_msg, this);
400 DOMWidgetView.__super__.initialize.apply(this, arguments);
402 DOMWidgetView.__super__.initialize.apply(this, arguments);
401 this.on('displayed', this.show, this);
403 this.on('displayed', this.show, this);
402 },
404 },
403
405
404 on_msg: function(msg) {
406 on_msg: function(msg) {
405 // Handle DOM specific msgs.
407 // Handle DOM specific msgs.
406 switch(msg.msg_type) {
408 switch(msg.msg_type) {
407 case 'add_class':
409 case 'add_class':
408 this.add_class(msg.selector, msg.class_list);
410 this.add_class(msg.selector, msg.class_list);
409 break;
411 break;
410 case 'remove_class':
412 case 'remove_class':
411 this.remove_class(msg.selector, msg.class_list);
413 this.remove_class(msg.selector, msg.class_list);
412 break;
414 break;
413 }
415 }
414 },
416 },
415
417
416 add_class: function (selector, class_list) {
418 add_class: function (selector, class_list) {
417 // Add a DOM class to an element.
419 // Add a DOM class to an element.
418 this._get_selector_element(selector).addClass(class_list);
420 this._get_selector_element(selector).addClass(class_list);
419 },
421 },
420
422
421 remove_class: function (selector, class_list) {
423 remove_class: function (selector, class_list) {
422 // Remove a DOM class from an element.
424 // Remove a DOM class from an element.
423 this._get_selector_element(selector).removeClass(class_list);
425 this._get_selector_element(selector).removeClass(class_list);
424 },
426 },
425
427
426 update: function () {
428 update: function () {
427 // Update the contents of this view
429 // Update the contents of this view
428 //
430 //
429 // Called when the model is changed. The model may have been
431 // Called when the model is changed. The model may have been
430 // changed by another view or by a state update from the back-end.
432 // changed by another view or by a state update from the back-end.
431 // The very first update seems to happen before the element is
433 // The very first update seems to happen before the element is
432 // finished rendering so we use setTimeout to give the element time
434 // finished rendering so we use setTimeout to give the element time
433 // to render
435 // to render
434 var e = this.$el;
436 var e = this.$el;
435 var visible = this.model.get('visible');
437 var visible = this.model.get('visible');
436 setTimeout(function() {e.toggle(visible);},0);
438 setTimeout(function() {e.toggle(visible);},0);
437
439
438 var css = this.model.get('_css');
440 var css = this.model.get('_css');
439 if (css === undefined) {return;}
441 if (css === undefined) {return;}
440 for (var i = 0; i < css.length; i++) {
442 for (var i = 0; i < css.length; i++) {
441 // Apply the css traits to all elements that match the selector.
443 // Apply the css traits to all elements that match the selector.
442 var selector = css[i][0];
444 var selector = css[i][0];
443 var elements = this._get_selector_element(selector);
445 var elements = this._get_selector_element(selector);
444 if (elements.length > 0) {
446 if (elements.length > 0) {
445 var trait_key = css[i][1];
447 var trait_key = css[i][1];
446 var trait_value = css[i][2];
448 var trait_value = css[i][2];
447 elements.css(trait_key ,trait_value);
449 elements.css(trait_key ,trait_value);
448 }
450 }
449 }
451 }
450 },
452 },
451
453
452 _get_selector_element: function (selector) {
454 _get_selector_element: function (selector) {
453 // Get the elements via the css selector.
455 // Get the elements via the css selector.
454
456
455 // If the selector is blank, apply the style to the $el_to_style
457 // If the selector is blank, apply the style to the $el_to_style
456 // element. If the $el_to_style element is not defined, use apply
458 // element. If the $el_to_style element is not defined, use apply
457 // the style to the view's element.
459 // the style to the view's element.
458 var elements;
460 var elements;
459 if (!selector) {
461 if (!selector) {
460 if (this.$el_to_style === undefined) {
462 if (this.$el_to_style === undefined) {
461 elements = this.$el;
463 elements = this.$el;
462 } else {
464 } else {
463 elements = this.$el_to_style;
465 elements = this.$el_to_style;
464 }
466 }
465 } else {
467 } else {
466 elements = this.$el.find(selector);
468 elements = this.$el.find(selector);
467 }
469 }
468 return elements;
470 return elements;
469 },
471 },
470 });
472 });
471
473
472 var widget = {
474 var widget = {
473 'WidgetModel': WidgetModel,
475 'WidgetModel': WidgetModel,
474 'WidgetView': WidgetView,
476 'WidgetView': WidgetView,
475 'DOMWidgetView': DOMWidgetView,
477 'DOMWidgetView': DOMWidgetView,
476 };
478 };
477
479
478 // For backwards compatability.
480 // For backwards compatability.
479 $.extend(IPython, widget);
481 $.extend(IPython, widget);
480
482
481 return widget;
483 return widget;
482 });
484 });
@@ -1,440 +1,453
1 """Base Widget class. Allows user to create widgets in the back-end that render
1 """Base Widget class. Allows user to create widgets in the back-end that render
2 in the IPython notebook front-end.
2 in the IPython notebook front-end.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (c) 2013, the IPython Development Team.
5 # Copyright (c) 2013, the IPython Development Team.
6 #
6 #
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8 #
8 #
9 # The full license is in the file COPYING.txt, distributed with this software.
9 # The full license is in the file COPYING.txt, distributed with this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 from contextlib import contextmanager
15 from contextlib import contextmanager
16
16
17 from IPython.core.getipython import get_ipython
17 from IPython.core.getipython import get_ipython
18 from IPython.kernel.comm import Comm
18 from IPython.kernel.comm import Comm
19 from IPython.config import LoggingConfigurable
19 from IPython.config import LoggingConfigurable
20 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, Tuple, Int
20 from IPython.utils.traitlets import Unicode, Dict, Instance, Bool, List, Tuple, Int
21 from IPython.utils.py3compat import string_types
21 from IPython.utils.py3compat import string_types
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Classes
24 # Classes
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 class CallbackDispatcher(LoggingConfigurable):
26 class CallbackDispatcher(LoggingConfigurable):
27 """A structure for registering and running callbacks"""
27 """A structure for registering and running callbacks"""
28 callbacks = List()
28 callbacks = List()
29
29
30 def __call__(self, *args, **kwargs):
30 def __call__(self, *args, **kwargs):
31 """Call all of the registered callbacks."""
31 """Call all of the registered callbacks."""
32 value = None
32 value = None
33 for callback in self.callbacks:
33 for callback in self.callbacks:
34 try:
34 try:
35 local_value = callback(*args, **kwargs)
35 local_value = callback(*args, **kwargs)
36 except Exception as e:
36 except Exception as e:
37 ip = get_ipython()
37 ip = get_ipython()
38 if ip is None:
38 if ip is None:
39 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
39 self.log.warn("Exception in callback %s: %s", callback, e, exc_info=True)
40 else:
40 else:
41 ip.showtraceback()
41 ip.showtraceback()
42 else:
42 else:
43 value = local_value if local_value is not None else value
43 value = local_value if local_value is not None else value
44 return value
44 return value
45
45
46 def register_callback(self, callback, remove=False):
46 def register_callback(self, callback, remove=False):
47 """(Un)Register a callback
47 """(Un)Register a callback
48
48
49 Parameters
49 Parameters
50 ----------
50 ----------
51 callback: method handle
51 callback: method handle
52 Method to be registered or unregistered.
52 Method to be registered or unregistered.
53 remove=False: bool
53 remove=False: bool
54 Whether to unregister the callback."""
54 Whether to unregister the callback."""
55
55
56 # (Un)Register the callback.
56 # (Un)Register the callback.
57 if remove and callback in self.callbacks:
57 if remove and callback in self.callbacks:
58 self.callbacks.remove(callback)
58 self.callbacks.remove(callback)
59 elif not remove and callback not in self.callbacks:
59 elif not remove and callback not in self.callbacks:
60 self.callbacks.append(callback)
60 self.callbacks.append(callback)
61
61
62 def _show_traceback(method):
62 def _show_traceback(method):
63 """decorator for showing tracebacks in IPython"""
63 """decorator for showing tracebacks in IPython"""
64 def m(self, *args, **kwargs):
64 def m(self, *args, **kwargs):
65 try:
65 try:
66 return(method(self, *args, **kwargs))
66 return(method(self, *args, **kwargs))
67 except Exception as e:
67 except Exception as e:
68 ip = get_ipython()
68 ip = get_ipython()
69 if ip is None:
69 if ip is None:
70 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
70 self.log.warn("Exception in widget method %s: %s", method, e, exc_info=True)
71 else:
71 else:
72 ip.showtraceback()
72 ip.showtraceback()
73 return m
73 return m
74
74
75 class Widget(LoggingConfigurable):
75 class Widget(LoggingConfigurable):
76 #-------------------------------------------------------------------------
76 #-------------------------------------------------------------------------
77 # Class attributes
77 # Class attributes
78 #-------------------------------------------------------------------------
78 #-------------------------------------------------------------------------
79 _widget_construction_callback = None
79 _widget_construction_callback = None
80 widgets = {}
80 widgets = {}
81
81
82 @staticmethod
82 @staticmethod
83 def on_widget_constructed(callback):
83 def on_widget_constructed(callback):
84 """Registers a callback to be called when a widget is constructed.
84 """Registers a callback to be called when a widget is constructed.
85
85
86 The callback must have the following signature:
86 The callback must have the following signature:
87 callback(widget)"""
87 callback(widget)"""
88 Widget._widget_construction_callback = callback
88 Widget._widget_construction_callback = callback
89
89
90 @staticmethod
90 @staticmethod
91 def _call_widget_constructed(widget):
91 def _call_widget_constructed(widget):
92 """Static method, called when a widget is constructed."""
92 """Static method, called when a widget is constructed."""
93 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
93 if Widget._widget_construction_callback is not None and callable(Widget._widget_construction_callback):
94 Widget._widget_construction_callback(widget)
94 Widget._widget_construction_callback(widget)
95
95
96 #-------------------------------------------------------------------------
96 #-------------------------------------------------------------------------
97 # Traits
97 # Traits
98 #-------------------------------------------------------------------------
98 #-------------------------------------------------------------------------
99 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
99 _model_name = Unicode('WidgetModel', help="""Name of the backbone model
100 registered in the front-end to create and sync this widget with.""")
100 registered in the front-end to create and sync this widget with.""")
101 _view_name = Unicode(help="""Default view registered in the front-end
101 _view_name = Unicode(help="""Default view registered in the front-end
102 to use to represent the widget.""", sync=True)
102 to use to represent the widget.""", sync=True)
103 _comm = Instance('IPython.kernel.comm.Comm')
103 _comm = Instance('IPython.kernel.comm.Comm')
104
104
105 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
105 msg_throttle = Int(3, sync=True, help="""Maximum number of msgs the
106 front-end can send before receiving an idle msg from the back-end.""")
106 front-end can send before receiving an idle msg from the back-end.""")
107
107
108 keys = List()
108 keys = List()
109 def _keys_default(self):
109 def _keys_default(self):
110 return [name for name in self.traits(sync=True)]
110 return [name for name in self.traits(sync=True)]
111
111
112 _property_lock = Tuple((None, None))
112 _property_lock = Tuple((None, None))
113
113
114 _display_callbacks = Instance(CallbackDispatcher, ())
114 _display_callbacks = Instance(CallbackDispatcher, ())
115 _msg_callbacks = Instance(CallbackDispatcher, ())
115 _msg_callbacks = Instance(CallbackDispatcher, ())
116
116
117 #-------------------------------------------------------------------------
117 #-------------------------------------------------------------------------
118 # (Con/de)structor
118 # (Con/de)structor
119 #-------------------------------------------------------------------------
119 #-------------------------------------------------------------------------
120 def __init__(self, **kwargs):
120 def __init__(self, **kwargs):
121 """Public constructor"""
121 """Public constructor"""
122 super(Widget, self).__init__(**kwargs)
122 super(Widget, self).__init__(**kwargs)
123
123
124 self.on_trait_change(self._handle_property_changed, self.keys)
124 self.on_trait_change(self._handle_property_changed, self.keys)
125 Widget._call_widget_constructed(self)
125 Widget._call_widget_constructed(self)
126
126
127 def __del__(self):
127 def __del__(self):
128 """Object disposal"""
128 """Object disposal"""
129 self.close()
129 self.close()
130
130
131 #-------------------------------------------------------------------------
131 #-------------------------------------------------------------------------
132 # Properties
132 # Properties
133 #-------------------------------------------------------------------------
133 #-------------------------------------------------------------------------
134
134
135 @property
135 @property
136 def comm(self):
136 def comm(self):
137 """Gets the Comm associated with this widget.
137 """Gets the Comm associated with this widget.
138
138
139 If a Comm doesn't exist yet, a Comm will be created automagically."""
139 If a Comm doesn't exist yet, a Comm will be created automagically."""
140 if self._comm is None:
140 if self._comm is None:
141 # Create a comm.
141 # Create a comm.
142 self._comm = Comm(target_name=self._model_name)
142 self._comm = Comm(target_name=self._model_name)
143 self._comm.on_msg(self._handle_msg)
143 self._comm.on_msg(self._handle_msg)
144 self._comm.on_close(self._close)
144 self._comm.on_close(self._close)
145 Widget.widgets[self.model_id] = self
145 Widget.widgets[self.model_id] = self
146
146
147 # first update
147 # first update
148 self.send_state()
148 self.send_state()
149 return self._comm
149 return self._comm
150
150
151 @property
151 @property
152 def model_id(self):
152 def model_id(self):
153 """Gets the model id of this widget.
153 """Gets the model id of this widget.
154
154
155 If a Comm doesn't exist yet, a Comm will be created automagically."""
155 If a Comm doesn't exist yet, a Comm will be created automagically."""
156 return self.comm.comm_id
156 return self.comm.comm_id
157
157
158 #-------------------------------------------------------------------------
158 #-------------------------------------------------------------------------
159 # Methods
159 # Methods
160 #-------------------------------------------------------------------------
160 #-------------------------------------------------------------------------
161 def _close(self):
161 def _close(self):
162 """Private close - cleanup objects, registry entries"""
162 """Private close - cleanup objects, registry entries"""
163 del Widget.widgets[self.model_id]
163 del Widget.widgets[self.model_id]
164 self._comm = None
164 self._comm = None
165
165
166 def close(self):
166 def close(self):
167 """Close method.
167 """Close method.
168
168
169 Closes the widget which closes the underlying comm.
169 Closes the widget which closes the underlying comm.
170 When the comm is closed, all of the widget views are automatically
170 When the comm is closed, all of the widget views are automatically
171 removed from the front-end."""
171 removed from the front-end."""
172 if self._comm is not None:
172 if self._comm is not None:
173 self._comm.close()
173 self._comm.close()
174 self._close()
174 self._close()
175
175
176 def send_state(self, key=None):
176 def send_state(self, key=None):
177 """Sends the widget state, or a piece of it, to the front-end.
177 """Sends the widget state, or a piece of it, to the front-end.
178
178
179 Parameters
179 Parameters
180 ----------
180 ----------
181 key : unicode (optional)
181 key : unicode (optional)
182 A single property's name to sync with the front-end.
182 A single property's name to sync with the front-end.
183 """
183 """
184 self._send({
184 self._send({
185 "method" : "update",
185 "method" : "update",
186 "state" : self.get_state()
186 "state" : self.get_state()
187 })
187 })
188
188
189 def get_state(self, key=None):
189 def get_state(self, key=None):
190 """Gets the widget state, or a piece of it.
190 """Gets the widget state, or a piece of it.
191
191
192 Parameters
192 Parameters
193 ----------
193 ----------
194 key : unicode (optional)
194 key : unicode (optional)
195 A single property's name to get.
195 A single property's name to get.
196 """
196 """
197 keys = self.keys if key is None else [key]
197 keys = self.keys if key is None else [key]
198 return {k: self._pack_widgets(getattr(self, k)) for k in keys}
198 state = {}
199
199 for k in keys:
200 f = self.trait_metadata(k, 'to_json')
201 if f is None:
202 f = self._trait_to_json
203 value = getattr(self, k)
204 state[k] = f(value)
205 return state
206
200 def send(self, content):
207 def send(self, content):
201 """Sends a custom msg to the widget model in the front-end.
208 """Sends a custom msg to the widget model in the front-end.
202
209
203 Parameters
210 Parameters
204 ----------
211 ----------
205 content : dict
212 content : dict
206 Content of the message to send.
213 Content of the message to send.
207 """
214 """
208 self._send({"method": "custom", "content": content})
215 self._send({"method": "custom", "content": content})
209
216
210 def on_msg(self, callback, remove=False):
217 def on_msg(self, callback, remove=False):
211 """(Un)Register a custom msg receive callback.
218 """(Un)Register a custom msg receive callback.
212
219
213 Parameters
220 Parameters
214 ----------
221 ----------
215 callback: callable
222 callback: callable
216 callback will be passed two arguments when a message arrives::
223 callback will be passed two arguments when a message arrives::
217
224
218 callback(widget, content)
225 callback(widget, content)
219
226
220 remove: bool
227 remove: bool
221 True if the callback should be unregistered."""
228 True if the callback should be unregistered."""
222 self._msg_callbacks.register_callback(callback, remove=remove)
229 self._msg_callbacks.register_callback(callback, remove=remove)
223
230
224 def on_displayed(self, callback, remove=False):
231 def on_displayed(self, callback, remove=False):
225 """(Un)Register a widget displayed callback.
232 """(Un)Register a widget displayed callback.
226
233
227 Parameters
234 Parameters
228 ----------
235 ----------
229 callback: method handler
236 callback: method handler
230 Must have a signature of::
237 Must have a signature of::
231
238
232 callback(widget, **kwargs)
239 callback(widget, **kwargs)
233
240
234 kwargs from display are passed through without modification.
241 kwargs from display are passed through without modification.
235 remove: bool
242 remove: bool
236 True if the callback should be unregistered."""
243 True if the callback should be unregistered."""
237 self._display_callbacks.register_callback(callback, remove=remove)
244 self._display_callbacks.register_callback(callback, remove=remove)
238
245
239 #-------------------------------------------------------------------------
246 #-------------------------------------------------------------------------
240 # Support methods
247 # Support methods
241 #-------------------------------------------------------------------------
248 #-------------------------------------------------------------------------
242 @contextmanager
249 @contextmanager
243 def _lock_property(self, key, value):
250 def _lock_property(self, key, value):
244 """Lock a property-value pair.
251 """Lock a property-value pair.
245
252
246 NOTE: This, in addition to the single lock for all state changes, is
253 NOTE: This, in addition to the single lock for all state changes, is
247 flawed. In the future we may want to look into buffering state changes
254 flawed. In the future we may want to look into buffering state changes
248 back to the front-end."""
255 back to the front-end."""
249 self._property_lock = (key, value)
256 self._property_lock = (key, value)
250 try:
257 try:
251 yield
258 yield
252 finally:
259 finally:
253 self._property_lock = (None, None)
260 self._property_lock = (None, None)
254
261
255 def _should_send_property(self, key, value):
262 def _should_send_property(self, key, value):
256 """Check the property lock (property_lock)"""
263 """Check the property lock (property_lock)"""
257 return key != self._property_lock[0] or \
264 return key != self._property_lock[0] or \
258 value != self._property_lock[1]
265 value != self._property_lock[1]
259
266
260 # Event handlers
267 # Event handlers
261 @_show_traceback
268 @_show_traceback
262 def _handle_msg(self, msg):
269 def _handle_msg(self, msg):
263 """Called when a msg is received from the front-end"""
270 """Called when a msg is received from the front-end"""
264 data = msg['content']['data']
271 data = msg['content']['data']
265 method = data['method']
272 method = data['method']
266 if not method in ['backbone', 'custom']:
273 if not method in ['backbone', 'custom']:
267 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
274 self.log.error('Unknown front-end to back-end widget msg with method "%s"' % method)
268
275
269 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
276 # Handle backbone sync methods CREATE, PATCH, and UPDATE all in one.
270 if method == 'backbone' and 'sync_data' in data:
277 if method == 'backbone' and 'sync_data' in data:
271 sync_data = data['sync_data']
278 sync_data = data['sync_data']
272 self._handle_receive_state(sync_data) # handles all methods
279 self._handle_receive_state(sync_data) # handles all methods
273
280
274 # Handle a custom msg from the front-end
281 # Handle a custom msg from the front-end
275 elif method == 'custom':
282 elif method == 'custom':
276 if 'content' in data:
283 if 'content' in data:
277 self._handle_custom_msg(data['content'])
284 self._handle_custom_msg(data['content'])
278
285
279 def _handle_receive_state(self, sync_data):
286 def _handle_receive_state(self, sync_data):
280 """Called when a state is received from the front-end."""
287 """Called when a state is received from the front-end."""
281 for name in self.keys:
288 for name in self.keys:
282 if name in sync_data:
289 if name in sync_data:
283 value = self._unpack_widgets(sync_data[name])
290 f = self.trait_metadata(name, 'from_json')
291 if f is None:
292 f = self._trait_from_json
293 value = f(sync_data[name])
284 with self._lock_property(name, value):
294 with self._lock_property(name, value):
285 setattr(self, name, value)
295 setattr(self, name, value)
286
296
287 def _handle_custom_msg(self, content):
297 def _handle_custom_msg(self, content):
288 """Called when a custom msg is received."""
298 """Called when a custom msg is received."""
289 self._msg_callbacks(self, content)
299 self._msg_callbacks(self, content)
290
300
291 def _handle_property_changed(self, name, old, new):
301 def _handle_property_changed(self, name, old, new):
292 """Called when a property has been changed."""
302 """Called when a property has been changed."""
293 # Make sure this isn't information that the front-end just sent us.
303 # Make sure this isn't information that the front-end just sent us.
294 if self._should_send_property(name, new):
304 if self._should_send_property(name, new):
295 # Send new state to front-end
305 # Send new state to front-end
296 self.send_state(key=name)
306 self.send_state(key=name)
297
307
298 def _handle_displayed(self, **kwargs):
308 def _handle_displayed(self, **kwargs):
299 """Called when a view has been displayed for this widget instance"""
309 """Called when a view has been displayed for this widget instance"""
300 self._display_callbacks(self, **kwargs)
310 self._display_callbacks(self, **kwargs)
301
311
302 def _pack_widgets(self, x):
312 def _trait_to_json(self, x):
303 """Recursively converts all widget instances to model id strings.
313 """Convert a trait value to json
304
314
305 Children widgets will be stored and transmitted to the front-end by
315 Traverse lists/tuples and dicts and serialize their values as well.
306 their model ids. Return value must be JSON-able."""
316 Replace any widgets with their model_id
317 """
307 if isinstance(x, dict):
318 if isinstance(x, dict):
308 return {k: self._pack_widgets(v) for k, v in x.items()}
319 return {k: self._trait_to_json(v) for k, v in x.items()}
309 elif isinstance(x, (list, tuple)):
320 elif isinstance(x, (list, tuple)):
310 return [self._pack_widgets(v) for v in x]
321 return [self._trait_to_json(v) for v in x]
311 elif isinstance(x, Widget):
322 elif isinstance(x, Widget):
312 return x.model_id
323 return "IPY_MODEL_" + x.model_id
313 else:
324 else:
314 return x # Value must be JSON-able
325 return x # Value must be JSON-able
315
326
316 def _unpack_widgets(self, x):
327 def _trait_from_json(self, x):
317 """Recursively converts all model id strings to widget instances.
328 """Convert json values to objects
318
329
319 Children widgets will be stored and transmitted to the front-end by
330 Replace any strings representing valid model id values to Widget references.
320 their model ids."""
331 """
321 if isinstance(x, dict):
332 if isinstance(x, dict):
322 return {k: self._unpack_widgets(v) for k, v in x.items()}
333 return {k: self._trait_from_json(v) for k, v in x.items()}
323 elif isinstance(x, (list, tuple)):
334 elif isinstance(x, (list, tuple)):
324 return [self._unpack_widgets(v) for v in x]
335 return [self._trait_from_json(v) for v in x]
325 elif isinstance(x, string_types):
336 elif isinstance(x, string_types) and x.startswith('IPY_MODEL_') and x[10:] in Widget.widgets:
326 return x if x not in Widget.widgets else Widget.widgets[x]
337 # we want to support having child widgets at any level in a hierarchy
338 # trusting that a widget UUID will not appear out in the wild
339 return Widget.widgets[x]
327 else:
340 else:
328 return x
341 return x
329
342
330 def _ipython_display_(self, **kwargs):
343 def _ipython_display_(self, **kwargs):
331 """Called when `IPython.display.display` is called on the widget."""
344 """Called when `IPython.display.display` is called on the widget."""
332 # Show view. By sending a display message, the comm is opened and the
345 # Show view. By sending a display message, the comm is opened and the
333 # initial state is sent.
346 # initial state is sent.
334 self._send({"method": "display"})
347 self._send({"method": "display"})
335 self._handle_displayed(**kwargs)
348 self._handle_displayed(**kwargs)
336
349
337 def _send(self, msg):
350 def _send(self, msg):
338 """Sends a message to the model in the front-end."""
351 """Sends a message to the model in the front-end."""
339 self.comm.send(msg)
352 self.comm.send(msg)
340
353
341
354
342 class DOMWidget(Widget):
355 class DOMWidget(Widget):
343 visible = Bool(True, help="Whether the widget is visible.", sync=True)
356 visible = Bool(True, help="Whether the widget is visible.", sync=True)
344 _css = List(sync=True) # Internal CSS property list: (selector, key, value)
357 _css = List(sync=True) # Internal CSS property list: (selector, key, value)
345
358
346 def get_css(self, key, selector=""):
359 def get_css(self, key, selector=""):
347 """Get a CSS property of the widget.
360 """Get a CSS property of the widget.
348
361
349 Note: This function does not actually request the CSS from the
362 Note: This function does not actually request the CSS from the
350 front-end; Only properties that have been set with set_css can be read.
363 front-end; Only properties that have been set with set_css can be read.
351
364
352 Parameters
365 Parameters
353 ----------
366 ----------
354 key: unicode
367 key: unicode
355 CSS key
368 CSS key
356 selector: unicode (optional)
369 selector: unicode (optional)
357 JQuery selector used when the CSS key/value was set.
370 JQuery selector used when the CSS key/value was set.
358 """
371 """
359 if selector in self._css and key in self._css[selector]:
372 if selector in self._css and key in self._css[selector]:
360 return self._css[selector][key]
373 return self._css[selector][key]
361 else:
374 else:
362 return None
375 return None
363
376
364 def set_css(self, dict_or_key, value=None, selector=''):
377 def set_css(self, dict_or_key, value=None, selector=''):
365 """Set one or more CSS properties of the widget.
378 """Set one or more CSS properties of the widget.
366
379
367 This function has two signatures:
380 This function has two signatures:
368 - set_css(css_dict, selector='')
381 - set_css(css_dict, selector='')
369 - set_css(key, value, selector='')
382 - set_css(key, value, selector='')
370
383
371 Parameters
384 Parameters
372 ----------
385 ----------
373 css_dict : dict
386 css_dict : dict
374 CSS key/value pairs to apply
387 CSS key/value pairs to apply
375 key: unicode
388 key: unicode
376 CSS key
389 CSS key
377 value:
390 value:
378 CSS value
391 CSS value
379 selector: unicode (optional, kwarg only)
392 selector: unicode (optional, kwarg only)
380 JQuery selector to use to apply the CSS key/value. If no selector
393 JQuery selector to use to apply the CSS key/value. If no selector
381 is provided, an empty selector is used. An empty selector makes the
394 is provided, an empty selector is used. An empty selector makes the
382 front-end try to apply the css to a default element. The default
395 front-end try to apply the css to a default element. The default
383 element is an attribute unique to each view, which is a DOM element
396 element is an attribute unique to each view, which is a DOM element
384 of the view that should be styled with common CSS (see
397 of the view that should be styled with common CSS (see
385 `$el_to_style` in the Javascript code).
398 `$el_to_style` in the Javascript code).
386 """
399 """
387 if value is None:
400 if value is None:
388 css_dict = dict_or_key
401 css_dict = dict_or_key
389 else:
402 else:
390 css_dict = {dict_or_key: value}
403 css_dict = {dict_or_key: value}
391
404
392 for (key, value) in css_dict.items():
405 for (key, value) in css_dict.items():
393 # First remove the selector/key pair from the css list if it exists.
406 # First remove the selector/key pair from the css list if it exists.
394 # Then add the selector/key pair and new value to the bottom of the
407 # Then add the selector/key pair and new value to the bottom of the
395 # list.
408 # list.
396 self._css = [x for x in self._css if not (x[0]==selector and x[1]==key)]
409 self._css = [x for x in self._css if not (x[0]==selector and x[1]==key)]
397 self._css += [(selector, key, value)]
410 self._css += [(selector, key, value)]
398 self.send_state('_css')
411 self.send_state('_css')
399
412
400 def add_class(self, class_names, selector=""):
413 def add_class(self, class_names, selector=""):
401 """Add class[es] to a DOM element.
414 """Add class[es] to a DOM element.
402
415
403 Parameters
416 Parameters
404 ----------
417 ----------
405 class_names: unicode or list
418 class_names: unicode or list
406 Class name(s) to add to the DOM element(s).
419 Class name(s) to add to the DOM element(s).
407 selector: unicode (optional)
420 selector: unicode (optional)
408 JQuery selector to select the DOM element(s) that the class(es) will
421 JQuery selector to select the DOM element(s) that the class(es) will
409 be added to.
422 be added to.
410 """
423 """
411 class_list = class_names
424 class_list = class_names
412 if isinstance(class_list, (list, tuple)):
425 if isinstance(class_list, (list, tuple)):
413 class_list = ' '.join(class_list)
426 class_list = ' '.join(class_list)
414
427
415 self.send({
428 self.send({
416 "msg_type" : "add_class",
429 "msg_type" : "add_class",
417 "class_list" : class_list,
430 "class_list" : class_list,
418 "selector" : selector
431 "selector" : selector
419 })
432 })
420
433
421 def remove_class(self, class_names, selector=""):
434 def remove_class(self, class_names, selector=""):
422 """Remove class[es] from a DOM element.
435 """Remove class[es] from a DOM element.
423
436
424 Parameters
437 Parameters
425 ----------
438 ----------
426 class_names: unicode or list
439 class_names: unicode or list
427 Class name(s) to remove from the DOM element(s).
440 Class name(s) to remove from the DOM element(s).
428 selector: unicode (optional)
441 selector: unicode (optional)
429 JQuery selector to select the DOM element(s) that the class(es) will
442 JQuery selector to select the DOM element(s) that the class(es) will
430 be removed from.
443 be removed from.
431 """
444 """
432 class_list = class_names
445 class_list = class_names
433 if isinstance(class_list, (list, tuple)):
446 if isinstance(class_list, (list, tuple)):
434 class_list = ' '.join(class_list)
447 class_list = ' '.join(class_list)
435
448
436 self.send({
449 self.send({
437 "msg_type" : "remove_class",
450 "msg_type" : "remove_class",
438 "class_list" : class_list,
451 "class_list" : class_list,
439 "selector" : selector,
452 "selector" : selector,
440 })
453 })
@@ -1,1153 +1,1174
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.traitlets."""
2 """Tests for IPython.utils.traitlets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6 #
6 #
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
7 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
8 # also under the terms of the Modified BSD License.
8 # also under the terms of the Modified BSD License.
9
9
10 import pickle
10 import pickle
11 import re
11 import re
12 import sys
12 import sys
13 from unittest import TestCase
13 from unittest import TestCase
14
14
15 import nose.tools as nt
15 import nose.tools as nt
16 from nose import SkipTest
16 from nose import SkipTest
17
17
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
19 HasTraits, MetaHasTraits, TraitType, Any, CBytes, Dict,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
20 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
21 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
21 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
22 ObjectName, DottedObjectName, CRegExp, link
22 ObjectName, DottedObjectName, CRegExp, link
23 )
23 )
24 from IPython.utils import py3compat
24 from IPython.utils import py3compat
25 from IPython.testing.decorators import skipif
25 from IPython.testing.decorators import skipif
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Helper classes for testing
28 # Helper classes for testing
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31
31
32 class HasTraitsStub(HasTraits):
32 class HasTraitsStub(HasTraits):
33
33
34 def _notify_trait(self, name, old, new):
34 def _notify_trait(self, name, old, new):
35 self._notify_name = name
35 self._notify_name = name
36 self._notify_old = old
36 self._notify_old = old
37 self._notify_new = new
37 self._notify_new = new
38
38
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Test classes
41 # Test classes
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44
44
45 class TestTraitType(TestCase):
45 class TestTraitType(TestCase):
46
46
47 def test_get_undefined(self):
47 def test_get_undefined(self):
48 class A(HasTraits):
48 class A(HasTraits):
49 a = TraitType
49 a = TraitType
50 a = A()
50 a = A()
51 self.assertEqual(a.a, Undefined)
51 self.assertEqual(a.a, Undefined)
52
52
53 def test_set(self):
53 def test_set(self):
54 class A(HasTraitsStub):
54 class A(HasTraitsStub):
55 a = TraitType
55 a = TraitType
56
56
57 a = A()
57 a = A()
58 a.a = 10
58 a.a = 10
59 self.assertEqual(a.a, 10)
59 self.assertEqual(a.a, 10)
60 self.assertEqual(a._notify_name, 'a')
60 self.assertEqual(a._notify_name, 'a')
61 self.assertEqual(a._notify_old, Undefined)
61 self.assertEqual(a._notify_old, Undefined)
62 self.assertEqual(a._notify_new, 10)
62 self.assertEqual(a._notify_new, 10)
63
63
64 def test_validate(self):
64 def test_validate(self):
65 class MyTT(TraitType):
65 class MyTT(TraitType):
66 def validate(self, inst, value):
66 def validate(self, inst, value):
67 return -1
67 return -1
68 class A(HasTraitsStub):
68 class A(HasTraitsStub):
69 tt = MyTT
69 tt = MyTT
70
70
71 a = A()
71 a = A()
72 a.tt = 10
72 a.tt = 10
73 self.assertEqual(a.tt, -1)
73 self.assertEqual(a.tt, -1)
74
74
75 def test_default_validate(self):
75 def test_default_validate(self):
76 class MyIntTT(TraitType):
76 class MyIntTT(TraitType):
77 def validate(self, obj, value):
77 def validate(self, obj, value):
78 if isinstance(value, int):
78 if isinstance(value, int):
79 return value
79 return value
80 self.error(obj, value)
80 self.error(obj, value)
81 class A(HasTraits):
81 class A(HasTraits):
82 tt = MyIntTT(10)
82 tt = MyIntTT(10)
83 a = A()
83 a = A()
84 self.assertEqual(a.tt, 10)
84 self.assertEqual(a.tt, 10)
85
85
86 # Defaults are validated when the HasTraits is instantiated
86 # Defaults are validated when the HasTraits is instantiated
87 class B(HasTraits):
87 class B(HasTraits):
88 tt = MyIntTT('bad default')
88 tt = MyIntTT('bad default')
89 self.assertRaises(TraitError, B)
89 self.assertRaises(TraitError, B)
90
90
91 def test_is_valid_for(self):
91 def test_is_valid_for(self):
92 class MyTT(TraitType):
92 class MyTT(TraitType):
93 def is_valid_for(self, value):
93 def is_valid_for(self, value):
94 return True
94 return True
95 class A(HasTraits):
95 class A(HasTraits):
96 tt = MyTT
96 tt = MyTT
97
97
98 a = A()
98 a = A()
99 a.tt = 10
99 a.tt = 10
100 self.assertEqual(a.tt, 10)
100 self.assertEqual(a.tt, 10)
101
101
102 def test_value_for(self):
102 def test_value_for(self):
103 class MyTT(TraitType):
103 class MyTT(TraitType):
104 def value_for(self, value):
104 def value_for(self, value):
105 return 20
105 return 20
106 class A(HasTraits):
106 class A(HasTraits):
107 tt = MyTT
107 tt = MyTT
108
108
109 a = A()
109 a = A()
110 a.tt = 10
110 a.tt = 10
111 self.assertEqual(a.tt, 20)
111 self.assertEqual(a.tt, 20)
112
112
113 def test_info(self):
113 def test_info(self):
114 class A(HasTraits):
114 class A(HasTraits):
115 tt = TraitType
115 tt = TraitType
116 a = A()
116 a = A()
117 self.assertEqual(A.tt.info(), 'any value')
117 self.assertEqual(A.tt.info(), 'any value')
118
118
119 def test_error(self):
119 def test_error(self):
120 class A(HasTraits):
120 class A(HasTraits):
121 tt = TraitType
121 tt = TraitType
122 a = A()
122 a = A()
123 self.assertRaises(TraitError, A.tt.error, a, 10)
123 self.assertRaises(TraitError, A.tt.error, a, 10)
124
124
125 def test_dynamic_initializer(self):
125 def test_dynamic_initializer(self):
126 class A(HasTraits):
126 class A(HasTraits):
127 x = Int(10)
127 x = Int(10)
128 def _x_default(self):
128 def _x_default(self):
129 return 11
129 return 11
130 class B(A):
130 class B(A):
131 x = Int(20)
131 x = Int(20)
132 class C(A):
132 class C(A):
133 def _x_default(self):
133 def _x_default(self):
134 return 21
134 return 21
135
135
136 a = A()
136 a = A()
137 self.assertEqual(a._trait_values, {})
137 self.assertEqual(a._trait_values, {})
138 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
138 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
139 self.assertEqual(a.x, 11)
139 self.assertEqual(a.x, 11)
140 self.assertEqual(a._trait_values, {'x': 11})
140 self.assertEqual(a._trait_values, {'x': 11})
141 b = B()
141 b = B()
142 self.assertEqual(b._trait_values, {'x': 20})
142 self.assertEqual(b._trait_values, {'x': 20})
143 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
143 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
144 self.assertEqual(b.x, 20)
144 self.assertEqual(b.x, 20)
145 c = C()
145 c = C()
146 self.assertEqual(c._trait_values, {})
146 self.assertEqual(c._trait_values, {})
147 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
147 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
148 self.assertEqual(c.x, 21)
148 self.assertEqual(c.x, 21)
149 self.assertEqual(c._trait_values, {'x': 21})
149 self.assertEqual(c._trait_values, {'x': 21})
150 # Ensure that the base class remains unmolested when the _default
150 # Ensure that the base class remains unmolested when the _default
151 # initializer gets overridden in a subclass.
151 # initializer gets overridden in a subclass.
152 a = A()
152 a = A()
153 c = C()
153 c = C()
154 self.assertEqual(a._trait_values, {})
154 self.assertEqual(a._trait_values, {})
155 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
155 self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
156 self.assertEqual(a.x, 11)
156 self.assertEqual(a.x, 11)
157 self.assertEqual(a._trait_values, {'x': 11})
157 self.assertEqual(a._trait_values, {'x': 11})
158
158
159
159
160
160
161 class TestHasTraitsMeta(TestCase):
161 class TestHasTraitsMeta(TestCase):
162
162
163 def test_metaclass(self):
163 def test_metaclass(self):
164 self.assertEqual(type(HasTraits), MetaHasTraits)
164 self.assertEqual(type(HasTraits), MetaHasTraits)
165
165
166 class A(HasTraits):
166 class A(HasTraits):
167 a = Int
167 a = Int
168
168
169 a = A()
169 a = A()
170 self.assertEqual(type(a.__class__), MetaHasTraits)
170 self.assertEqual(type(a.__class__), MetaHasTraits)
171 self.assertEqual(a.a,0)
171 self.assertEqual(a.a,0)
172 a.a = 10
172 a.a = 10
173 self.assertEqual(a.a,10)
173 self.assertEqual(a.a,10)
174
174
175 class B(HasTraits):
175 class B(HasTraits):
176 b = Int()
176 b = Int()
177
177
178 b = B()
178 b = B()
179 self.assertEqual(b.b,0)
179 self.assertEqual(b.b,0)
180 b.b = 10
180 b.b = 10
181 self.assertEqual(b.b,10)
181 self.assertEqual(b.b,10)
182
182
183 class C(HasTraits):
183 class C(HasTraits):
184 c = Int(30)
184 c = Int(30)
185
185
186 c = C()
186 c = C()
187 self.assertEqual(c.c,30)
187 self.assertEqual(c.c,30)
188 c.c = 10
188 c.c = 10
189 self.assertEqual(c.c,10)
189 self.assertEqual(c.c,10)
190
190
191 def test_this_class(self):
191 def test_this_class(self):
192 class A(HasTraits):
192 class A(HasTraits):
193 t = This()
193 t = This()
194 tt = This()
194 tt = This()
195 class B(A):
195 class B(A):
196 tt = This()
196 tt = This()
197 ttt = This()
197 ttt = This()
198 self.assertEqual(A.t.this_class, A)
198 self.assertEqual(A.t.this_class, A)
199 self.assertEqual(B.t.this_class, A)
199 self.assertEqual(B.t.this_class, A)
200 self.assertEqual(B.tt.this_class, B)
200 self.assertEqual(B.tt.this_class, B)
201 self.assertEqual(B.ttt.this_class, B)
201 self.assertEqual(B.ttt.this_class, B)
202
202
203 class TestHasTraitsNotify(TestCase):
203 class TestHasTraitsNotify(TestCase):
204
204
205 def setUp(self):
205 def setUp(self):
206 self._notify1 = []
206 self._notify1 = []
207 self._notify2 = []
207 self._notify2 = []
208
208
209 def notify1(self, name, old, new):
209 def notify1(self, name, old, new):
210 self._notify1.append((name, old, new))
210 self._notify1.append((name, old, new))
211
211
212 def notify2(self, name, old, new):
212 def notify2(self, name, old, new):
213 self._notify2.append((name, old, new))
213 self._notify2.append((name, old, new))
214
214
215 def test_notify_all(self):
215 def test_notify_all(self):
216
216
217 class A(HasTraits):
217 class A(HasTraits):
218 a = Int
218 a = Int
219 b = Float
219 b = Float
220
220
221 a = A()
221 a = A()
222 a.on_trait_change(self.notify1)
222 a.on_trait_change(self.notify1)
223 a.a = 0
223 a.a = 0
224 self.assertEqual(len(self._notify1),0)
224 self.assertEqual(len(self._notify1),0)
225 a.b = 0.0
225 a.b = 0.0
226 self.assertEqual(len(self._notify1),0)
226 self.assertEqual(len(self._notify1),0)
227 a.a = 10
227 a.a = 10
228 self.assertTrue(('a',0,10) in self._notify1)
228 self.assertTrue(('a',0,10) in self._notify1)
229 a.b = 10.0
229 a.b = 10.0
230 self.assertTrue(('b',0.0,10.0) in self._notify1)
230 self.assertTrue(('b',0.0,10.0) in self._notify1)
231 self.assertRaises(TraitError,setattr,a,'a','bad string')
231 self.assertRaises(TraitError,setattr,a,'a','bad string')
232 self.assertRaises(TraitError,setattr,a,'b','bad string')
232 self.assertRaises(TraitError,setattr,a,'b','bad string')
233 self._notify1 = []
233 self._notify1 = []
234 a.on_trait_change(self.notify1,remove=True)
234 a.on_trait_change(self.notify1,remove=True)
235 a.a = 20
235 a.a = 20
236 a.b = 20.0
236 a.b = 20.0
237 self.assertEqual(len(self._notify1),0)
237 self.assertEqual(len(self._notify1),0)
238
238
239 def test_notify_one(self):
239 def test_notify_one(self):
240
240
241 class A(HasTraits):
241 class A(HasTraits):
242 a = Int
242 a = Int
243 b = Float
243 b = Float
244
244
245 a = A()
245 a = A()
246 a.on_trait_change(self.notify1, 'a')
246 a.on_trait_change(self.notify1, 'a')
247 a.a = 0
247 a.a = 0
248 self.assertEqual(len(self._notify1),0)
248 self.assertEqual(len(self._notify1),0)
249 a.a = 10
249 a.a = 10
250 self.assertTrue(('a',0,10) in self._notify1)
250 self.assertTrue(('a',0,10) in self._notify1)
251 self.assertRaises(TraitError,setattr,a,'a','bad string')
251 self.assertRaises(TraitError,setattr,a,'a','bad string')
252
252
253 def test_subclass(self):
253 def test_subclass(self):
254
254
255 class A(HasTraits):
255 class A(HasTraits):
256 a = Int
256 a = Int
257
257
258 class B(A):
258 class B(A):
259 b = Float
259 b = Float
260
260
261 b = B()
261 b = B()
262 self.assertEqual(b.a,0)
262 self.assertEqual(b.a,0)
263 self.assertEqual(b.b,0.0)
263 self.assertEqual(b.b,0.0)
264 b.a = 100
264 b.a = 100
265 b.b = 100.0
265 b.b = 100.0
266 self.assertEqual(b.a,100)
266 self.assertEqual(b.a,100)
267 self.assertEqual(b.b,100.0)
267 self.assertEqual(b.b,100.0)
268
268
269 def test_notify_subclass(self):
269 def test_notify_subclass(self):
270
270
271 class A(HasTraits):
271 class A(HasTraits):
272 a = Int
272 a = Int
273
273
274 class B(A):
274 class B(A):
275 b = Float
275 b = Float
276
276
277 b = B()
277 b = B()
278 b.on_trait_change(self.notify1, 'a')
278 b.on_trait_change(self.notify1, 'a')
279 b.on_trait_change(self.notify2, 'b')
279 b.on_trait_change(self.notify2, 'b')
280 b.a = 0
280 b.a = 0
281 b.b = 0.0
281 b.b = 0.0
282 self.assertEqual(len(self._notify1),0)
282 self.assertEqual(len(self._notify1),0)
283 self.assertEqual(len(self._notify2),0)
283 self.assertEqual(len(self._notify2),0)
284 b.a = 10
284 b.a = 10
285 b.b = 10.0
285 b.b = 10.0
286 self.assertTrue(('a',0,10) in self._notify1)
286 self.assertTrue(('a',0,10) in self._notify1)
287 self.assertTrue(('b',0.0,10.0) in self._notify2)
287 self.assertTrue(('b',0.0,10.0) in self._notify2)
288
288
289 def test_static_notify(self):
289 def test_static_notify(self):
290
290
291 class A(HasTraits):
291 class A(HasTraits):
292 a = Int
292 a = Int
293 _notify1 = []
293 _notify1 = []
294 def _a_changed(self, name, old, new):
294 def _a_changed(self, name, old, new):
295 self._notify1.append((name, old, new))
295 self._notify1.append((name, old, new))
296
296
297 a = A()
297 a = A()
298 a.a = 0
298 a.a = 0
299 # This is broken!!!
299 # This is broken!!!
300 self.assertEqual(len(a._notify1),0)
300 self.assertEqual(len(a._notify1),0)
301 a.a = 10
301 a.a = 10
302 self.assertTrue(('a',0,10) in a._notify1)
302 self.assertTrue(('a',0,10) in a._notify1)
303
303
304 class B(A):
304 class B(A):
305 b = Float
305 b = Float
306 _notify2 = []
306 _notify2 = []
307 def _b_changed(self, name, old, new):
307 def _b_changed(self, name, old, new):
308 self._notify2.append((name, old, new))
308 self._notify2.append((name, old, new))
309
309
310 b = B()
310 b = B()
311 b.a = 10
311 b.a = 10
312 b.b = 10.0
312 b.b = 10.0
313 self.assertTrue(('a',0,10) in b._notify1)
313 self.assertTrue(('a',0,10) in b._notify1)
314 self.assertTrue(('b',0.0,10.0) in b._notify2)
314 self.assertTrue(('b',0.0,10.0) in b._notify2)
315
315
316 def test_notify_args(self):
316 def test_notify_args(self):
317
317
318 def callback0():
318 def callback0():
319 self.cb = ()
319 self.cb = ()
320 def callback1(name):
320 def callback1(name):
321 self.cb = (name,)
321 self.cb = (name,)
322 def callback2(name, new):
322 def callback2(name, new):
323 self.cb = (name, new)
323 self.cb = (name, new)
324 def callback3(name, old, new):
324 def callback3(name, old, new):
325 self.cb = (name, old, new)
325 self.cb = (name, old, new)
326
326
327 class A(HasTraits):
327 class A(HasTraits):
328 a = Int
328 a = Int
329
329
330 a = A()
330 a = A()
331 a.on_trait_change(callback0, 'a')
331 a.on_trait_change(callback0, 'a')
332 a.a = 10
332 a.a = 10
333 self.assertEqual(self.cb,())
333 self.assertEqual(self.cb,())
334 a.on_trait_change(callback0, 'a', remove=True)
334 a.on_trait_change(callback0, 'a', remove=True)
335
335
336 a.on_trait_change(callback1, 'a')
336 a.on_trait_change(callback1, 'a')
337 a.a = 100
337 a.a = 100
338 self.assertEqual(self.cb,('a',))
338 self.assertEqual(self.cb,('a',))
339 a.on_trait_change(callback1, 'a', remove=True)
339 a.on_trait_change(callback1, 'a', remove=True)
340
340
341 a.on_trait_change(callback2, 'a')
341 a.on_trait_change(callback2, 'a')
342 a.a = 1000
342 a.a = 1000
343 self.assertEqual(self.cb,('a',1000))
343 self.assertEqual(self.cb,('a',1000))
344 a.on_trait_change(callback2, 'a', remove=True)
344 a.on_trait_change(callback2, 'a', remove=True)
345
345
346 a.on_trait_change(callback3, 'a')
346 a.on_trait_change(callback3, 'a')
347 a.a = 10000
347 a.a = 10000
348 self.assertEqual(self.cb,('a',1000,10000))
348 self.assertEqual(self.cb,('a',1000,10000))
349 a.on_trait_change(callback3, 'a', remove=True)
349 a.on_trait_change(callback3, 'a', remove=True)
350
350
351 self.assertEqual(len(a._trait_notifiers['a']),0)
351 self.assertEqual(len(a._trait_notifiers['a']),0)
352
352
353 def test_notify_only_once(self):
353 def test_notify_only_once(self):
354
354
355 class A(HasTraits):
355 class A(HasTraits):
356 listen_to = ['a']
356 listen_to = ['a']
357
357
358 a = Int(0)
358 a = Int(0)
359 b = 0
359 b = 0
360
360
361 def __init__(self, **kwargs):
361 def __init__(self, **kwargs):
362 super(A, self).__init__(**kwargs)
362 super(A, self).__init__(**kwargs)
363 self.on_trait_change(self.listener1, ['a'])
363 self.on_trait_change(self.listener1, ['a'])
364
364
365 def listener1(self, name, old, new):
365 def listener1(self, name, old, new):
366 self.b += 1
366 self.b += 1
367
367
368 class B(A):
368 class B(A):
369
369
370 c = 0
370 c = 0
371 d = 0
371 d = 0
372
372
373 def __init__(self, **kwargs):
373 def __init__(self, **kwargs):
374 super(B, self).__init__(**kwargs)
374 super(B, self).__init__(**kwargs)
375 self.on_trait_change(self.listener2)
375 self.on_trait_change(self.listener2)
376
376
377 def listener2(self, name, old, new):
377 def listener2(self, name, old, new):
378 self.c += 1
378 self.c += 1
379
379
380 def _a_changed(self, name, old, new):
380 def _a_changed(self, name, old, new):
381 self.d += 1
381 self.d += 1
382
382
383 b = B()
383 b = B()
384 b.a += 1
384 b.a += 1
385 self.assertEqual(b.b, b.c)
385 self.assertEqual(b.b, b.c)
386 self.assertEqual(b.b, b.d)
386 self.assertEqual(b.b, b.d)
387 b.a += 1
387 b.a += 1
388 self.assertEqual(b.b, b.c)
388 self.assertEqual(b.b, b.c)
389 self.assertEqual(b.b, b.d)
389 self.assertEqual(b.b, b.d)
390
390
391
391
392 class TestHasTraits(TestCase):
392 class TestHasTraits(TestCase):
393
393
394 def test_trait_names(self):
394 def test_trait_names(self):
395 class A(HasTraits):
395 class A(HasTraits):
396 i = Int
396 i = Int
397 f = Float
397 f = Float
398 a = A()
398 a = A()
399 self.assertEqual(sorted(a.trait_names()),['f','i'])
399 self.assertEqual(sorted(a.trait_names()),['f','i'])
400 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
400 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
401
401
402 def test_trait_metadata(self):
402 def test_trait_metadata(self):
403 class A(HasTraits):
403 class A(HasTraits):
404 i = Int(config_key='MY_VALUE')
404 i = Int(config_key='MY_VALUE')
405 a = A()
405 a = A()
406 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
406 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
407
407
408 def test_traits(self):
408 def test_traits(self):
409 class A(HasTraits):
409 class A(HasTraits):
410 i = Int
410 i = Int
411 f = Float
411 f = Float
412 a = A()
412 a = A()
413 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
413 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
414 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
414 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
415
415
416 def test_traits_metadata(self):
416 def test_traits_metadata(self):
417 class A(HasTraits):
417 class A(HasTraits):
418 i = Int(config_key='VALUE1', other_thing='VALUE2')
418 i = Int(config_key='VALUE1', other_thing='VALUE2')
419 f = Float(config_key='VALUE3', other_thing='VALUE2')
419 f = Float(config_key='VALUE3', other_thing='VALUE2')
420 j = Int(0)
420 j = Int(0)
421 a = A()
421 a = A()
422 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
422 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
423 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
423 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
424 self.assertEqual(traits, dict(i=A.i))
424 self.assertEqual(traits, dict(i=A.i))
425
425
426 # This passes, but it shouldn't because I am replicating a bug in
426 # This passes, but it shouldn't because I am replicating a bug in
427 # traits.
427 # traits.
428 traits = a.traits(config_key=lambda v: True)
428 traits = a.traits(config_key=lambda v: True)
429 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
429 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
430
430
431 def test_init(self):
431 def test_init(self):
432 class A(HasTraits):
432 class A(HasTraits):
433 i = Int()
433 i = Int()
434 x = Float()
434 x = Float()
435 a = A(i=1, x=10.0)
435 a = A(i=1, x=10.0)
436 self.assertEqual(a.i, 1)
436 self.assertEqual(a.i, 1)
437 self.assertEqual(a.x, 10.0)
437 self.assertEqual(a.x, 10.0)
438
438
439 def test_positional_args(self):
439 def test_positional_args(self):
440 class A(HasTraits):
440 class A(HasTraits):
441 i = Int(0)
441 i = Int(0)
442 def __init__(self, i):
442 def __init__(self, i):
443 super(A, self).__init__()
443 super(A, self).__init__()
444 self.i = i
444 self.i = i
445
445
446 a = A(5)
446 a = A(5)
447 self.assertEqual(a.i, 5)
447 self.assertEqual(a.i, 5)
448 # should raise TypeError if no positional arg given
448 # should raise TypeError if no positional arg given
449 self.assertRaises(TypeError, A)
449 self.assertRaises(TypeError, A)
450
450
451 #-----------------------------------------------------------------------------
451 #-----------------------------------------------------------------------------
452 # Tests for specific trait types
452 # Tests for specific trait types
453 #-----------------------------------------------------------------------------
453 #-----------------------------------------------------------------------------
454
454
455
455
456 class TestType(TestCase):
456 class TestType(TestCase):
457
457
458 def test_default(self):
458 def test_default(self):
459
459
460 class B(object): pass
460 class B(object): pass
461 class A(HasTraits):
461 class A(HasTraits):
462 klass = Type
462 klass = Type
463
463
464 a = A()
464 a = A()
465 self.assertEqual(a.klass, None)
465 self.assertEqual(a.klass, None)
466
466
467 a.klass = B
467 a.klass = B
468 self.assertEqual(a.klass, B)
468 self.assertEqual(a.klass, B)
469 self.assertRaises(TraitError, setattr, a, 'klass', 10)
469 self.assertRaises(TraitError, setattr, a, 'klass', 10)
470
470
471 def test_value(self):
471 def test_value(self):
472
472
473 class B(object): pass
473 class B(object): pass
474 class C(object): pass
474 class C(object): pass
475 class A(HasTraits):
475 class A(HasTraits):
476 klass = Type(B)
476 klass = Type(B)
477
477
478 a = A()
478 a = A()
479 self.assertEqual(a.klass, B)
479 self.assertEqual(a.klass, B)
480 self.assertRaises(TraitError, setattr, a, 'klass', C)
480 self.assertRaises(TraitError, setattr, a, 'klass', C)
481 self.assertRaises(TraitError, setattr, a, 'klass', object)
481 self.assertRaises(TraitError, setattr, a, 'klass', object)
482 a.klass = B
482 a.klass = B
483
483
484 def test_allow_none(self):
484 def test_allow_none(self):
485
485
486 class B(object): pass
486 class B(object): pass
487 class C(B): pass
487 class C(B): pass
488 class A(HasTraits):
488 class A(HasTraits):
489 klass = Type(B, allow_none=False)
489 klass = Type(B, allow_none=False)
490
490
491 a = A()
491 a = A()
492 self.assertEqual(a.klass, B)
492 self.assertEqual(a.klass, B)
493 self.assertRaises(TraitError, setattr, a, 'klass', None)
493 self.assertRaises(TraitError, setattr, a, 'klass', None)
494 a.klass = C
494 a.klass = C
495 self.assertEqual(a.klass, C)
495 self.assertEqual(a.klass, C)
496
496
497 def test_validate_klass(self):
497 def test_validate_klass(self):
498
498
499 class A(HasTraits):
499 class A(HasTraits):
500 klass = Type('no strings allowed')
500 klass = Type('no strings allowed')
501
501
502 self.assertRaises(ImportError, A)
502 self.assertRaises(ImportError, A)
503
503
504 class A(HasTraits):
504 class A(HasTraits):
505 klass = Type('rub.adub.Duck')
505 klass = Type('rub.adub.Duck')
506
506
507 self.assertRaises(ImportError, A)
507 self.assertRaises(ImportError, A)
508
508
509 def test_validate_default(self):
509 def test_validate_default(self):
510
510
511 class B(object): pass
511 class B(object): pass
512 class A(HasTraits):
512 class A(HasTraits):
513 klass = Type('bad default', B)
513 klass = Type('bad default', B)
514
514
515 self.assertRaises(ImportError, A)
515 self.assertRaises(ImportError, A)
516
516
517 class C(HasTraits):
517 class C(HasTraits):
518 klass = Type(None, B, allow_none=False)
518 klass = Type(None, B, allow_none=False)
519
519
520 self.assertRaises(TraitError, C)
520 self.assertRaises(TraitError, C)
521
521
522 def test_str_klass(self):
522 def test_str_klass(self):
523
523
524 class A(HasTraits):
524 class A(HasTraits):
525 klass = Type('IPython.utils.ipstruct.Struct')
525 klass = Type('IPython.utils.ipstruct.Struct')
526
526
527 from IPython.utils.ipstruct import Struct
527 from IPython.utils.ipstruct import Struct
528 a = A()
528 a = A()
529 a.klass = Struct
529 a.klass = Struct
530 self.assertEqual(a.klass, Struct)
530 self.assertEqual(a.klass, Struct)
531
531
532 self.assertRaises(TraitError, setattr, a, 'klass', 10)
532 self.assertRaises(TraitError, setattr, a, 'klass', 10)
533
533
534 def test_set_str_klass(self):
534 def test_set_str_klass(self):
535
535
536 class A(HasTraits):
536 class A(HasTraits):
537 klass = Type()
537 klass = Type()
538
538
539 a = A(klass='IPython.utils.ipstruct.Struct')
539 a = A(klass='IPython.utils.ipstruct.Struct')
540 from IPython.utils.ipstruct import Struct
540 from IPython.utils.ipstruct import Struct
541 self.assertEqual(a.klass, Struct)
541 self.assertEqual(a.klass, Struct)
542
542
543 class TestInstance(TestCase):
543 class TestInstance(TestCase):
544
544
545 def test_basic(self):
545 def test_basic(self):
546 class Foo(object): pass
546 class Foo(object): pass
547 class Bar(Foo): pass
547 class Bar(Foo): pass
548 class Bah(object): pass
548 class Bah(object): pass
549
549
550 class A(HasTraits):
550 class A(HasTraits):
551 inst = Instance(Foo)
551 inst = Instance(Foo)
552
552
553 a = A()
553 a = A()
554 self.assertTrue(a.inst is None)
554 self.assertTrue(a.inst is None)
555 a.inst = Foo()
555 a.inst = Foo()
556 self.assertTrue(isinstance(a.inst, Foo))
556 self.assertTrue(isinstance(a.inst, Foo))
557 a.inst = Bar()
557 a.inst = Bar()
558 self.assertTrue(isinstance(a.inst, Foo))
558 self.assertTrue(isinstance(a.inst, Foo))
559 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
559 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
560 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
560 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
561 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
561 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
562
562
563 def test_default_klass(self):
564 class Foo(object): pass
565 class Bar(Foo): pass
566 class Bah(object): pass
567
568 class FooInstance(Instance):
569 klass = Foo
570
571 class A(HasTraits):
572 inst = FooInstance()
573
574 a = A()
575 self.assertTrue(a.inst is None)
576 a.inst = Foo()
577 self.assertTrue(isinstance(a.inst, Foo))
578 a.inst = Bar()
579 self.assertTrue(isinstance(a.inst, Foo))
580 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
581 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
582 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
583
563 def test_unique_default_value(self):
584 def test_unique_default_value(self):
564 class Foo(object): pass
585 class Foo(object): pass
565 class A(HasTraits):
586 class A(HasTraits):
566 inst = Instance(Foo,(),{})
587 inst = Instance(Foo,(),{})
567
588
568 a = A()
589 a = A()
569 b = A()
590 b = A()
570 self.assertTrue(a.inst is not b.inst)
591 self.assertTrue(a.inst is not b.inst)
571
592
572 def test_args_kw(self):
593 def test_args_kw(self):
573 class Foo(object):
594 class Foo(object):
574 def __init__(self, c): self.c = c
595 def __init__(self, c): self.c = c
575 class Bar(object): pass
596 class Bar(object): pass
576 class Bah(object):
597 class Bah(object):
577 def __init__(self, c, d):
598 def __init__(self, c, d):
578 self.c = c; self.d = d
599 self.c = c; self.d = d
579
600
580 class A(HasTraits):
601 class A(HasTraits):
581 inst = Instance(Foo, (10,))
602 inst = Instance(Foo, (10,))
582 a = A()
603 a = A()
583 self.assertEqual(a.inst.c, 10)
604 self.assertEqual(a.inst.c, 10)
584
605
585 class B(HasTraits):
606 class B(HasTraits):
586 inst = Instance(Bah, args=(10,), kw=dict(d=20))
607 inst = Instance(Bah, args=(10,), kw=dict(d=20))
587 b = B()
608 b = B()
588 self.assertEqual(b.inst.c, 10)
609 self.assertEqual(b.inst.c, 10)
589 self.assertEqual(b.inst.d, 20)
610 self.assertEqual(b.inst.d, 20)
590
611
591 class C(HasTraits):
612 class C(HasTraits):
592 inst = Instance(Foo)
613 inst = Instance(Foo)
593 c = C()
614 c = C()
594 self.assertTrue(c.inst is None)
615 self.assertTrue(c.inst is None)
595
616
596 def test_bad_default(self):
617 def test_bad_default(self):
597 class Foo(object): pass
618 class Foo(object): pass
598
619
599 class A(HasTraits):
620 class A(HasTraits):
600 inst = Instance(Foo, allow_none=False)
621 inst = Instance(Foo, allow_none=False)
601
622
602 self.assertRaises(TraitError, A)
623 self.assertRaises(TraitError, A)
603
624
604 def test_instance(self):
625 def test_instance(self):
605 class Foo(object): pass
626 class Foo(object): pass
606
627
607 def inner():
628 def inner():
608 class A(HasTraits):
629 class A(HasTraits):
609 inst = Instance(Foo())
630 inst = Instance(Foo())
610
631
611 self.assertRaises(TraitError, inner)
632 self.assertRaises(TraitError, inner)
612
633
613
634
614 class TestThis(TestCase):
635 class TestThis(TestCase):
615
636
616 def test_this_class(self):
637 def test_this_class(self):
617 class Foo(HasTraits):
638 class Foo(HasTraits):
618 this = This
639 this = This
619
640
620 f = Foo()
641 f = Foo()
621 self.assertEqual(f.this, None)
642 self.assertEqual(f.this, None)
622 g = Foo()
643 g = Foo()
623 f.this = g
644 f.this = g
624 self.assertEqual(f.this, g)
645 self.assertEqual(f.this, g)
625 self.assertRaises(TraitError, setattr, f, 'this', 10)
646 self.assertRaises(TraitError, setattr, f, 'this', 10)
626
647
627 def test_this_inst(self):
648 def test_this_inst(self):
628 class Foo(HasTraits):
649 class Foo(HasTraits):
629 this = This()
650 this = This()
630
651
631 f = Foo()
652 f = Foo()
632 f.this = Foo()
653 f.this = Foo()
633 self.assertTrue(isinstance(f.this, Foo))
654 self.assertTrue(isinstance(f.this, Foo))
634
655
635 def test_subclass(self):
656 def test_subclass(self):
636 class Foo(HasTraits):
657 class Foo(HasTraits):
637 t = This()
658 t = This()
638 class Bar(Foo):
659 class Bar(Foo):
639 pass
660 pass
640 f = Foo()
661 f = Foo()
641 b = Bar()
662 b = Bar()
642 f.t = b
663 f.t = b
643 b.t = f
664 b.t = f
644 self.assertEqual(f.t, b)
665 self.assertEqual(f.t, b)
645 self.assertEqual(b.t, f)
666 self.assertEqual(b.t, f)
646
667
647 def test_subclass_override(self):
668 def test_subclass_override(self):
648 class Foo(HasTraits):
669 class Foo(HasTraits):
649 t = This()
670 t = This()
650 class Bar(Foo):
671 class Bar(Foo):
651 t = This()
672 t = This()
652 f = Foo()
673 f = Foo()
653 b = Bar()
674 b = Bar()
654 f.t = b
675 f.t = b
655 self.assertEqual(f.t, b)
676 self.assertEqual(f.t, b)
656 self.assertRaises(TraitError, setattr, b, 't', f)
677 self.assertRaises(TraitError, setattr, b, 't', f)
657
678
658 class TraitTestBase(TestCase):
679 class TraitTestBase(TestCase):
659 """A best testing class for basic trait types."""
680 """A best testing class for basic trait types."""
660
681
661 def assign(self, value):
682 def assign(self, value):
662 self.obj.value = value
683 self.obj.value = value
663
684
664 def coerce(self, value):
685 def coerce(self, value):
665 return value
686 return value
666
687
667 def test_good_values(self):
688 def test_good_values(self):
668 if hasattr(self, '_good_values'):
689 if hasattr(self, '_good_values'):
669 for value in self._good_values:
690 for value in self._good_values:
670 self.assign(value)
691 self.assign(value)
671 self.assertEqual(self.obj.value, self.coerce(value))
692 self.assertEqual(self.obj.value, self.coerce(value))
672
693
673 def test_bad_values(self):
694 def test_bad_values(self):
674 if hasattr(self, '_bad_values'):
695 if hasattr(self, '_bad_values'):
675 for value in self._bad_values:
696 for value in self._bad_values:
676 try:
697 try:
677 self.assertRaises(TraitError, self.assign, value)
698 self.assertRaises(TraitError, self.assign, value)
678 except AssertionError:
699 except AssertionError:
679 assert False, value
700 assert False, value
680
701
681 def test_default_value(self):
702 def test_default_value(self):
682 if hasattr(self, '_default_value'):
703 if hasattr(self, '_default_value'):
683 self.assertEqual(self._default_value, self.obj.value)
704 self.assertEqual(self._default_value, self.obj.value)
684
705
685 def test_allow_none(self):
706 def test_allow_none(self):
686 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
707 if (hasattr(self, '_bad_values') and hasattr(self, '_good_values') and
687 None in self._bad_values):
708 None in self._bad_values):
688 trait=self.obj.traits()['value']
709 trait=self.obj.traits()['value']
689 try:
710 try:
690 trait.allow_none = True
711 trait.allow_none = True
691 self._bad_values.remove(None)
712 self._bad_values.remove(None)
692 #skip coerce. Allow None casts None to None.
713 #skip coerce. Allow None casts None to None.
693 self.assign(None)
714 self.assign(None)
694 self.assertEqual(self.obj.value,None)
715 self.assertEqual(self.obj.value,None)
695 self.test_good_values()
716 self.test_good_values()
696 self.test_bad_values()
717 self.test_bad_values()
697 finally:
718 finally:
698 #tear down
719 #tear down
699 trait.allow_none = False
720 trait.allow_none = False
700 self._bad_values.append(None)
721 self._bad_values.append(None)
701
722
702 def tearDown(self):
723 def tearDown(self):
703 # restore default value after tests, if set
724 # restore default value after tests, if set
704 if hasattr(self, '_default_value'):
725 if hasattr(self, '_default_value'):
705 self.obj.value = self._default_value
726 self.obj.value = self._default_value
706
727
707
728
708 class AnyTrait(HasTraits):
729 class AnyTrait(HasTraits):
709
730
710 value = Any
731 value = Any
711
732
712 class AnyTraitTest(TraitTestBase):
733 class AnyTraitTest(TraitTestBase):
713
734
714 obj = AnyTrait()
735 obj = AnyTrait()
715
736
716 _default_value = None
737 _default_value = None
717 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
738 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
718 _bad_values = []
739 _bad_values = []
719
740
720
741
721 class IntTrait(HasTraits):
742 class IntTrait(HasTraits):
722
743
723 value = Int(99)
744 value = Int(99)
724
745
725 class TestInt(TraitTestBase):
746 class TestInt(TraitTestBase):
726
747
727 obj = IntTrait()
748 obj = IntTrait()
728 _default_value = 99
749 _default_value = 99
729 _good_values = [10, -10]
750 _good_values = [10, -10]
730 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
751 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
731 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
752 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
732 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
753 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
733 if not py3compat.PY3:
754 if not py3compat.PY3:
734 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
755 _bad_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
735
756
736
757
737 class LongTrait(HasTraits):
758 class LongTrait(HasTraits):
738
759
739 value = Long(99 if py3compat.PY3 else long(99))
760 value = Long(99 if py3compat.PY3 else long(99))
740
761
741 class TestLong(TraitTestBase):
762 class TestLong(TraitTestBase):
742
763
743 obj = LongTrait()
764 obj = LongTrait()
744
765
745 _default_value = 99 if py3compat.PY3 else long(99)
766 _default_value = 99 if py3compat.PY3 else long(99)
746 _good_values = [10, -10]
767 _good_values = [10, -10]
747 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
768 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,),
748 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
769 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
749 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
770 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
750 u'-10.1']
771 u'-10.1']
751 if not py3compat.PY3:
772 if not py3compat.PY3:
752 # maxint undefined on py3, because int == long
773 # maxint undefined on py3, because int == long
753 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
774 _good_values.extend([long(10), long(-10), 10*sys.maxint, -10*sys.maxint])
754 _bad_values.extend([[long(10)], (long(10),)])
775 _bad_values.extend([[long(10)], (long(10),)])
755
776
756 @skipif(py3compat.PY3, "not relevant on py3")
777 @skipif(py3compat.PY3, "not relevant on py3")
757 def test_cast_small(self):
778 def test_cast_small(self):
758 """Long casts ints to long"""
779 """Long casts ints to long"""
759 self.obj.value = 10
780 self.obj.value = 10
760 self.assertEqual(type(self.obj.value), long)
781 self.assertEqual(type(self.obj.value), long)
761
782
762
783
763 class IntegerTrait(HasTraits):
784 class IntegerTrait(HasTraits):
764 value = Integer(1)
785 value = Integer(1)
765
786
766 class TestInteger(TestLong):
787 class TestInteger(TestLong):
767 obj = IntegerTrait()
788 obj = IntegerTrait()
768 _default_value = 1
789 _default_value = 1
769
790
770 def coerce(self, n):
791 def coerce(self, n):
771 return int(n)
792 return int(n)
772
793
773 @skipif(py3compat.PY3, "not relevant on py3")
794 @skipif(py3compat.PY3, "not relevant on py3")
774 def test_cast_small(self):
795 def test_cast_small(self):
775 """Integer casts small longs to int"""
796 """Integer casts small longs to int"""
776 if py3compat.PY3:
797 if py3compat.PY3:
777 raise SkipTest("not relevant on py3")
798 raise SkipTest("not relevant on py3")
778
799
779 self.obj.value = long(100)
800 self.obj.value = long(100)
780 self.assertEqual(type(self.obj.value), int)
801 self.assertEqual(type(self.obj.value), int)
781
802
782
803
783 class FloatTrait(HasTraits):
804 class FloatTrait(HasTraits):
784
805
785 value = Float(99.0)
806 value = Float(99.0)
786
807
787 class TestFloat(TraitTestBase):
808 class TestFloat(TraitTestBase):
788
809
789 obj = FloatTrait()
810 obj = FloatTrait()
790
811
791 _default_value = 99.0
812 _default_value = 99.0
792 _good_values = [10, -10, 10.1, -10.1]
813 _good_values = [10, -10, 10.1, -10.1]
793 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
814 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
794 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
815 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
795 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
816 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
796 if not py3compat.PY3:
817 if not py3compat.PY3:
797 _bad_values.extend([long(10), long(-10)])
818 _bad_values.extend([long(10), long(-10)])
798
819
799
820
800 class ComplexTrait(HasTraits):
821 class ComplexTrait(HasTraits):
801
822
802 value = Complex(99.0-99.0j)
823 value = Complex(99.0-99.0j)
803
824
804 class TestComplex(TraitTestBase):
825 class TestComplex(TraitTestBase):
805
826
806 obj = ComplexTrait()
827 obj = ComplexTrait()
807
828
808 _default_value = 99.0-99.0j
829 _default_value = 99.0-99.0j
809 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
830 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
810 10.1j, 10.1+10.1j, 10.1-10.1j]
831 10.1j, 10.1+10.1j, 10.1-10.1j]
811 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
832 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
812 if not py3compat.PY3:
833 if not py3compat.PY3:
813 _bad_values.extend([long(10), long(-10)])
834 _bad_values.extend([long(10), long(-10)])
814
835
815
836
816 class BytesTrait(HasTraits):
837 class BytesTrait(HasTraits):
817
838
818 value = Bytes(b'string')
839 value = Bytes(b'string')
819
840
820 class TestBytes(TraitTestBase):
841 class TestBytes(TraitTestBase):
821
842
822 obj = BytesTrait()
843 obj = BytesTrait()
823
844
824 _default_value = b'string'
845 _default_value = b'string'
825 _good_values = [b'10', b'-10', b'10L',
846 _good_values = [b'10', b'-10', b'10L',
826 b'-10L', b'10.1', b'-10.1', b'string']
847 b'-10L', b'10.1', b'-10.1', b'string']
827 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
848 _bad_values = [10, -10, 10.1, -10.1, 1j, [10],
828 ['ten'],{'ten': 10},(10,), None, u'string']
849 ['ten'],{'ten': 10},(10,), None, u'string']
829 if not py3compat.PY3:
850 if not py3compat.PY3:
830 _bad_values.extend([long(10), long(-10)])
851 _bad_values.extend([long(10), long(-10)])
831
852
832
853
833 class UnicodeTrait(HasTraits):
854 class UnicodeTrait(HasTraits):
834
855
835 value = Unicode(u'unicode')
856 value = Unicode(u'unicode')
836
857
837 class TestUnicode(TraitTestBase):
858 class TestUnicode(TraitTestBase):
838
859
839 obj = UnicodeTrait()
860 obj = UnicodeTrait()
840
861
841 _default_value = u'unicode'
862 _default_value = u'unicode'
842 _good_values = ['10', '-10', '10L', '-10L', '10.1',
863 _good_values = ['10', '-10', '10L', '-10L', '10.1',
843 '-10.1', '', u'', 'string', u'string', u"€"]
864 '-10.1', '', u'', 'string', u'string', u"€"]
844 _bad_values = [10, -10, 10.1, -10.1, 1j,
865 _bad_values = [10, -10, 10.1, -10.1, 1j,
845 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
866 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
846 if not py3compat.PY3:
867 if not py3compat.PY3:
847 _bad_values.extend([long(10), long(-10)])
868 _bad_values.extend([long(10), long(-10)])
848
869
849
870
850 class ObjectNameTrait(HasTraits):
871 class ObjectNameTrait(HasTraits):
851 value = ObjectName("abc")
872 value = ObjectName("abc")
852
873
853 class TestObjectName(TraitTestBase):
874 class TestObjectName(TraitTestBase):
854 obj = ObjectNameTrait()
875 obj = ObjectNameTrait()
855
876
856 _default_value = "abc"
877 _default_value = "abc"
857 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
878 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
858 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
879 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
859 None, object(), object]
880 None, object(), object]
860 if sys.version_info[0] < 3:
881 if sys.version_info[0] < 3:
861 _bad_values.append(u"ΓΎ")
882 _bad_values.append(u"ΓΎ")
862 else:
883 else:
863 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
884 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
864
885
865
886
866 class DottedObjectNameTrait(HasTraits):
887 class DottedObjectNameTrait(HasTraits):
867 value = DottedObjectName("a.b")
888 value = DottedObjectName("a.b")
868
889
869 class TestDottedObjectName(TraitTestBase):
890 class TestDottedObjectName(TraitTestBase):
870 obj = DottedObjectNameTrait()
891 obj = DottedObjectNameTrait()
871
892
872 _default_value = "a.b"
893 _default_value = "a.b"
873 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
894 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
874 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
895 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None]
875 if sys.version_info[0] < 3:
896 if sys.version_info[0] < 3:
876 _bad_values.append(u"t.ΓΎ")
897 _bad_values.append(u"t.ΓΎ")
877 else:
898 else:
878 _good_values.append(u"t.ΓΎ")
899 _good_values.append(u"t.ΓΎ")
879
900
880
901
881 class TCPAddressTrait(HasTraits):
902 class TCPAddressTrait(HasTraits):
882
903
883 value = TCPAddress()
904 value = TCPAddress()
884
905
885 class TestTCPAddress(TraitTestBase):
906 class TestTCPAddress(TraitTestBase):
886
907
887 obj = TCPAddressTrait()
908 obj = TCPAddressTrait()
888
909
889 _default_value = ('127.0.0.1',0)
910 _default_value = ('127.0.0.1',0)
890 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
911 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
891 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
912 _bad_values = [(0,0),('localhost',10.0),('localhost',-1), None]
892
913
893 class ListTrait(HasTraits):
914 class ListTrait(HasTraits):
894
915
895 value = List(Int)
916 value = List(Int)
896
917
897 class TestList(TraitTestBase):
918 class TestList(TraitTestBase):
898
919
899 obj = ListTrait()
920 obj = ListTrait()
900
921
901 _default_value = []
922 _default_value = []
902 _good_values = [[], [1], list(range(10)), (1,2)]
923 _good_values = [[], [1], list(range(10)), (1,2)]
903 _bad_values = [10, [1,'a'], 'a']
924 _bad_values = [10, [1,'a'], 'a']
904
925
905 def coerce(self, value):
926 def coerce(self, value):
906 if value is not None:
927 if value is not None:
907 value = list(value)
928 value = list(value)
908 return value
929 return value
909
930
910 class Foo(object):
931 class Foo(object):
911 pass
932 pass
912
933
913 class InstanceListTrait(HasTraits):
934 class InstanceListTrait(HasTraits):
914
935
915 value = List(Instance(__name__+'.Foo'))
936 value = List(Instance(__name__+'.Foo'))
916
937
917 class TestInstanceList(TraitTestBase):
938 class TestInstanceList(TraitTestBase):
918
939
919 obj = InstanceListTrait()
940 obj = InstanceListTrait()
920
941
921 def test_klass(self):
942 def test_klass(self):
922 """Test that the instance klass is properly assigned."""
943 """Test that the instance klass is properly assigned."""
923 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
944 self.assertIs(self.obj.traits()['value']._trait.klass, Foo)
924
945
925 _default_value = []
946 _default_value = []
926 _good_values = [[Foo(), Foo(), None], None]
947 _good_values = [[Foo(), Foo(), None], None]
927 _bad_values = [['1', 2,], '1', [Foo]]
948 _bad_values = [['1', 2,], '1', [Foo]]
928
949
929 class LenListTrait(HasTraits):
950 class LenListTrait(HasTraits):
930
951
931 value = List(Int, [0], minlen=1, maxlen=2)
952 value = List(Int, [0], minlen=1, maxlen=2)
932
953
933 class TestLenList(TraitTestBase):
954 class TestLenList(TraitTestBase):
934
955
935 obj = LenListTrait()
956 obj = LenListTrait()
936
957
937 _default_value = [0]
958 _default_value = [0]
938 _good_values = [[1], [1,2], (1,2)]
959 _good_values = [[1], [1,2], (1,2)]
939 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
960 _bad_values = [10, [1,'a'], 'a', [], list(range(3))]
940
961
941 def coerce(self, value):
962 def coerce(self, value):
942 if value is not None:
963 if value is not None:
943 value = list(value)
964 value = list(value)
944 return value
965 return value
945
966
946 class TupleTrait(HasTraits):
967 class TupleTrait(HasTraits):
947
968
948 value = Tuple(Int(allow_none=True))
969 value = Tuple(Int(allow_none=True))
949
970
950 class TestTupleTrait(TraitTestBase):
971 class TestTupleTrait(TraitTestBase):
951
972
952 obj = TupleTrait()
973 obj = TupleTrait()
953
974
954 _default_value = None
975 _default_value = None
955 _good_values = [(1,), None, (0,), [1], (None,)]
976 _good_values = [(1,), None, (0,), [1], (None,)]
956 _bad_values = [10, (1,2), ('a'), ()]
977 _bad_values = [10, (1,2), ('a'), ()]
957
978
958 def coerce(self, value):
979 def coerce(self, value):
959 if value is not None:
980 if value is not None:
960 value = tuple(value)
981 value = tuple(value)
961 return value
982 return value
962
983
963 def test_invalid_args(self):
984 def test_invalid_args(self):
964 self.assertRaises(TypeError, Tuple, 5)
985 self.assertRaises(TypeError, Tuple, 5)
965 self.assertRaises(TypeError, Tuple, default_value='hello')
986 self.assertRaises(TypeError, Tuple, default_value='hello')
966 t = Tuple(Int, CBytes, default_value=(1,5))
987 t = Tuple(Int, CBytes, default_value=(1,5))
967
988
968 class LooseTupleTrait(HasTraits):
989 class LooseTupleTrait(HasTraits):
969
990
970 value = Tuple((1,2,3))
991 value = Tuple((1,2,3))
971
992
972 class TestLooseTupleTrait(TraitTestBase):
993 class TestLooseTupleTrait(TraitTestBase):
973
994
974 obj = LooseTupleTrait()
995 obj = LooseTupleTrait()
975
996
976 _default_value = (1,2,3)
997 _default_value = (1,2,3)
977 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
998 _good_values = [(1,), None, [1], (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
978 _bad_values = [10, 'hello', {}]
999 _bad_values = [10, 'hello', {}]
979
1000
980 def coerce(self, value):
1001 def coerce(self, value):
981 if value is not None:
1002 if value is not None:
982 value = tuple(value)
1003 value = tuple(value)
983 return value
1004 return value
984
1005
985 def test_invalid_args(self):
1006 def test_invalid_args(self):
986 self.assertRaises(TypeError, Tuple, 5)
1007 self.assertRaises(TypeError, Tuple, 5)
987 self.assertRaises(TypeError, Tuple, default_value='hello')
1008 self.assertRaises(TypeError, Tuple, default_value='hello')
988 t = Tuple(Int, CBytes, default_value=(1,5))
1009 t = Tuple(Int, CBytes, default_value=(1,5))
989
1010
990
1011
991 class MultiTupleTrait(HasTraits):
1012 class MultiTupleTrait(HasTraits):
992
1013
993 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
1014 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
994
1015
995 class TestMultiTuple(TraitTestBase):
1016 class TestMultiTuple(TraitTestBase):
996
1017
997 obj = MultiTupleTrait()
1018 obj = MultiTupleTrait()
998
1019
999 _default_value = (99,b'bottles')
1020 _default_value = (99,b'bottles')
1000 _good_values = [(1,b'a'), (2,b'b')]
1021 _good_values = [(1,b'a'), (2,b'b')]
1001 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1022 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
1002
1023
1003 class CRegExpTrait(HasTraits):
1024 class CRegExpTrait(HasTraits):
1004
1025
1005 value = CRegExp(r'')
1026 value = CRegExp(r'')
1006
1027
1007 class TestCRegExp(TraitTestBase):
1028 class TestCRegExp(TraitTestBase):
1008
1029
1009 def coerce(self, value):
1030 def coerce(self, value):
1010 return re.compile(value)
1031 return re.compile(value)
1011
1032
1012 obj = CRegExpTrait()
1033 obj = CRegExpTrait()
1013
1034
1014 _default_value = re.compile(r'')
1035 _default_value = re.compile(r'')
1015 _good_values = [r'\d+', re.compile(r'\d+')]
1036 _good_values = [r'\d+', re.compile(r'\d+')]
1016 _bad_values = [r'(', None, ()]
1037 _bad_values = [r'(', None, ()]
1017
1038
1018 class DictTrait(HasTraits):
1039 class DictTrait(HasTraits):
1019 value = Dict()
1040 value = Dict()
1020
1041
1021 def test_dict_assignment():
1042 def test_dict_assignment():
1022 d = dict()
1043 d = dict()
1023 c = DictTrait()
1044 c = DictTrait()
1024 c.value = d
1045 c.value = d
1025 d['a'] = 5
1046 d['a'] = 5
1026 nt.assert_equal(d, c.value)
1047 nt.assert_equal(d, c.value)
1027 nt.assert_true(c.value is d)
1048 nt.assert_true(c.value is d)
1028
1049
1029 class TestLink(TestCase):
1050 class TestLink(TestCase):
1030 def test_connect_same(self):
1051 def test_connect_same(self):
1031 """Verify two traitlets of the same type can be linked together using link."""
1052 """Verify two traitlets of the same type can be linked together using link."""
1032
1053
1033 # Create two simple classes with Int traitlets.
1054 # Create two simple classes with Int traitlets.
1034 class A(HasTraits):
1055 class A(HasTraits):
1035 value = Int()
1056 value = Int()
1036 a = A(value=9)
1057 a = A(value=9)
1037 b = A(value=8)
1058 b = A(value=8)
1038
1059
1039 # Conenct the two classes.
1060 # Conenct the two classes.
1040 c = link((a, 'value'), (b, 'value'))
1061 c = link((a, 'value'), (b, 'value'))
1041
1062
1042 # Make sure the values are the same at the point of linking.
1063 # Make sure the values are the same at the point of linking.
1043 self.assertEqual(a.value, b.value)
1064 self.assertEqual(a.value, b.value)
1044
1065
1045 # Change one of the values to make sure they stay in sync.
1066 # Change one of the values to make sure they stay in sync.
1046 a.value = 5
1067 a.value = 5
1047 self.assertEqual(a.value, b.value)
1068 self.assertEqual(a.value, b.value)
1048 b.value = 6
1069 b.value = 6
1049 self.assertEqual(a.value, b.value)
1070 self.assertEqual(a.value, b.value)
1050
1071
1051 def test_link_different(self):
1072 def test_link_different(self):
1052 """Verify two traitlets of different types can be linked together using link."""
1073 """Verify two traitlets of different types can be linked together using link."""
1053
1074
1054 # Create two simple classes with Int traitlets.
1075 # Create two simple classes with Int traitlets.
1055 class A(HasTraits):
1076 class A(HasTraits):
1056 value = Int()
1077 value = Int()
1057 class B(HasTraits):
1078 class B(HasTraits):
1058 count = Int()
1079 count = Int()
1059 a = A(value=9)
1080 a = A(value=9)
1060 b = B(count=8)
1081 b = B(count=8)
1061
1082
1062 # Conenct the two classes.
1083 # Conenct the two classes.
1063 c = link((a, 'value'), (b, 'count'))
1084 c = link((a, 'value'), (b, 'count'))
1064
1085
1065 # Make sure the values are the same at the point of linking.
1086 # Make sure the values are the same at the point of linking.
1066 self.assertEqual(a.value, b.count)
1087 self.assertEqual(a.value, b.count)
1067
1088
1068 # Change one of the values to make sure they stay in sync.
1089 # Change one of the values to make sure they stay in sync.
1069 a.value = 5
1090 a.value = 5
1070 self.assertEqual(a.value, b.count)
1091 self.assertEqual(a.value, b.count)
1071 b.count = 4
1092 b.count = 4
1072 self.assertEqual(a.value, b.count)
1093 self.assertEqual(a.value, b.count)
1073
1094
1074 def test_unlink(self):
1095 def test_unlink(self):
1075 """Verify two linked traitlets can be unlinked."""
1096 """Verify two linked traitlets can be unlinked."""
1076
1097
1077 # Create two simple classes with Int traitlets.
1098 # Create two simple classes with Int traitlets.
1078 class A(HasTraits):
1099 class A(HasTraits):
1079 value = Int()
1100 value = Int()
1080 a = A(value=9)
1101 a = A(value=9)
1081 b = A(value=8)
1102 b = A(value=8)
1082
1103
1083 # Connect the two classes.
1104 # Connect the two classes.
1084 c = link((a, 'value'), (b, 'value'))
1105 c = link((a, 'value'), (b, 'value'))
1085 a.value = 4
1106 a.value = 4
1086 c.unlink()
1107 c.unlink()
1087
1108
1088 # Change one of the values to make sure they don't stay in sync.
1109 # Change one of the values to make sure they don't stay in sync.
1089 a.value = 5
1110 a.value = 5
1090 self.assertNotEqual(a.value, b.value)
1111 self.assertNotEqual(a.value, b.value)
1091
1112
1092 def test_callbacks(self):
1113 def test_callbacks(self):
1093 """Verify two linked traitlets have their callbacks called once."""
1114 """Verify two linked traitlets have their callbacks called once."""
1094
1115
1095 # Create two simple classes with Int traitlets.
1116 # Create two simple classes with Int traitlets.
1096 class A(HasTraits):
1117 class A(HasTraits):
1097 value = Int()
1118 value = Int()
1098 class B(HasTraits):
1119 class B(HasTraits):
1099 count = Int()
1120 count = Int()
1100 a = A(value=9)
1121 a = A(value=9)
1101 b = B(count=8)
1122 b = B(count=8)
1102
1123
1103 # Register callbacks that count.
1124 # Register callbacks that count.
1104 callback_count = []
1125 callback_count = []
1105 def a_callback(name, old, new):
1126 def a_callback(name, old, new):
1106 callback_count.append('a')
1127 callback_count.append('a')
1107 a.on_trait_change(a_callback, 'value')
1128 a.on_trait_change(a_callback, 'value')
1108 def b_callback(name, old, new):
1129 def b_callback(name, old, new):
1109 callback_count.append('b')
1130 callback_count.append('b')
1110 b.on_trait_change(b_callback, 'count')
1131 b.on_trait_change(b_callback, 'count')
1111
1132
1112 # Connect the two classes.
1133 # Connect the two classes.
1113 c = link((a, 'value'), (b, 'count'))
1134 c = link((a, 'value'), (b, 'count'))
1114
1135
1115 # Make sure b's count was set to a's value once.
1136 # Make sure b's count was set to a's value once.
1116 self.assertEqual(''.join(callback_count), 'b')
1137 self.assertEqual(''.join(callback_count), 'b')
1117 del callback_count[:]
1138 del callback_count[:]
1118
1139
1119 # Make sure a's value was set to b's count once.
1140 # Make sure a's value was set to b's count once.
1120 b.count = 5
1141 b.count = 5
1121 self.assertEqual(''.join(callback_count), 'ba')
1142 self.assertEqual(''.join(callback_count), 'ba')
1122 del callback_count[:]
1143 del callback_count[:]
1123
1144
1124 # Make sure b's count was set to a's value once.
1145 # Make sure b's count was set to a's value once.
1125 a.value = 4
1146 a.value = 4
1126 self.assertEqual(''.join(callback_count), 'ab')
1147 self.assertEqual(''.join(callback_count), 'ab')
1127 del callback_count[:]
1148 del callback_count[:]
1128
1149
1129 class Pickleable(HasTraits):
1150 class Pickleable(HasTraits):
1130 i = Int()
1151 i = Int()
1131 j = Int()
1152 j = Int()
1132
1153
1133 def _i_default(self):
1154 def _i_default(self):
1134 return 1
1155 return 1
1135
1156
1136 def _i_changed(self, name, old, new):
1157 def _i_changed(self, name, old, new):
1137 self.j = new
1158 self.j = new
1138
1159
1139 def test_pickle_hastraits():
1160 def test_pickle_hastraits():
1140 c = Pickleable()
1161 c = Pickleable()
1141 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1162 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1142 p = pickle.dumps(c, protocol)
1163 p = pickle.dumps(c, protocol)
1143 c2 = pickle.loads(p)
1164 c2 = pickle.loads(p)
1144 nt.assert_equal(c2.i, c.i)
1165 nt.assert_equal(c2.i, c.i)
1145 nt.assert_equal(c2.j, c.j)
1166 nt.assert_equal(c2.j, c.j)
1146
1167
1147 c.i = 5
1168 c.i = 5
1148 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1169 for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
1149 p = pickle.dumps(c, protocol)
1170 p = pickle.dumps(c, protocol)
1150 c2 = pickle.loads(p)
1171 c2 = pickle.loads(p)
1151 nt.assert_equal(c2.i, c.i)
1172 nt.assert_equal(c2.i, c.i)
1152 nt.assert_equal(c2.j, c.j)
1173 nt.assert_equal(c2.j, c.j)
1153
1174
@@ -1,1516 +1,1523
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Inheritance diagram:
31 Inheritance diagram:
32
32
33 .. inheritance-diagram:: IPython.utils.traitlets
33 .. inheritance-diagram:: IPython.utils.traitlets
34 :parts: 3
34 :parts: 3
35 """
35 """
36
36
37 # Copyright (c) IPython Development Team.
37 # Copyright (c) IPython Development Team.
38 # Distributed under the terms of the Modified BSD License.
38 # Distributed under the terms of the Modified BSD License.
39 #
39 #
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 # also under the terms of the Modified BSD License.
41 # also under the terms of the Modified BSD License.
42
42
43 import contextlib
43 import contextlib
44 import inspect
44 import inspect
45 import re
45 import re
46 import sys
46 import sys
47 import types
47 import types
48 from types import FunctionType
48 from types import FunctionType
49 try:
49 try:
50 from types import ClassType, InstanceType
50 from types import ClassType, InstanceType
51 ClassTypes = (ClassType, type)
51 ClassTypes = (ClassType, type)
52 except:
52 except:
53 ClassTypes = (type,)
53 ClassTypes = (type,)
54
54
55 from .importstring import import_item
55 from .importstring import import_item
56 from IPython.utils import py3compat
56 from IPython.utils import py3compat
57 from IPython.utils.py3compat import iteritems
57 from IPython.utils.py3compat import iteritems
58 from IPython.testing.skipdoctest import skip_doctest
58 from IPython.testing.skipdoctest import skip_doctest
59
59
60 SequenceTypes = (list, tuple, set, frozenset)
60 SequenceTypes = (list, tuple, set, frozenset)
61
61
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63 # Basic classes
63 # Basic classes
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65
65
66
66
67 class NoDefaultSpecified ( object ): pass
67 class NoDefaultSpecified ( object ): pass
68 NoDefaultSpecified = NoDefaultSpecified()
68 NoDefaultSpecified = NoDefaultSpecified()
69
69
70
70
71 class Undefined ( object ): pass
71 class Undefined ( object ): pass
72 Undefined = Undefined()
72 Undefined = Undefined()
73
73
74 class TraitError(Exception):
74 class TraitError(Exception):
75 pass
75 pass
76
76
77 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
78 # Utilities
78 # Utilities
79 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
80
80
81
81
82 def class_of ( object ):
82 def class_of ( object ):
83 """ Returns a string containing the class name of an object with the
83 """ Returns a string containing the class name of an object with the
84 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
84 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
85 'a PlotValue').
85 'a PlotValue').
86 """
86 """
87 if isinstance( object, py3compat.string_types ):
87 if isinstance( object, py3compat.string_types ):
88 return add_article( object )
88 return add_article( object )
89
89
90 return add_article( object.__class__.__name__ )
90 return add_article( object.__class__.__name__ )
91
91
92
92
93 def add_article ( name ):
93 def add_article ( name ):
94 """ Returns a string containing the correct indefinite article ('a' or 'an')
94 """ Returns a string containing the correct indefinite article ('a' or 'an')
95 prefixed to the specified string.
95 prefixed to the specified string.
96 """
96 """
97 if name[:1].lower() in 'aeiou':
97 if name[:1].lower() in 'aeiou':
98 return 'an ' + name
98 return 'an ' + name
99
99
100 return 'a ' + name
100 return 'a ' + name
101
101
102
102
103 def repr_type(obj):
103 def repr_type(obj):
104 """ Return a string representation of a value and its type for readable
104 """ Return a string representation of a value and its type for readable
105 error messages.
105 error messages.
106 """
106 """
107 the_type = type(obj)
107 the_type = type(obj)
108 if (not py3compat.PY3) and the_type is InstanceType:
108 if (not py3compat.PY3) and the_type is InstanceType:
109 # Old-style class.
109 # Old-style class.
110 the_type = obj.__class__
110 the_type = obj.__class__
111 msg = '%r %r' % (obj, the_type)
111 msg = '%r %r' % (obj, the_type)
112 return msg
112 return msg
113
113
114
114
115 def is_trait(t):
115 def is_trait(t):
116 """ Returns whether the given value is an instance or subclass of TraitType.
116 """ Returns whether the given value is an instance or subclass of TraitType.
117 """
117 """
118 return (isinstance(t, TraitType) or
118 return (isinstance(t, TraitType) or
119 (isinstance(t, type) and issubclass(t, TraitType)))
119 (isinstance(t, type) and issubclass(t, TraitType)))
120
120
121
121
122 def parse_notifier_name(name):
122 def parse_notifier_name(name):
123 """Convert the name argument to a list of names.
123 """Convert the name argument to a list of names.
124
124
125 Examples
125 Examples
126 --------
126 --------
127
127
128 >>> parse_notifier_name('a')
128 >>> parse_notifier_name('a')
129 ['a']
129 ['a']
130 >>> parse_notifier_name(['a','b'])
130 >>> parse_notifier_name(['a','b'])
131 ['a', 'b']
131 ['a', 'b']
132 >>> parse_notifier_name(None)
132 >>> parse_notifier_name(None)
133 ['anytrait']
133 ['anytrait']
134 """
134 """
135 if isinstance(name, str):
135 if isinstance(name, str):
136 return [name]
136 return [name]
137 elif name is None:
137 elif name is None:
138 return ['anytrait']
138 return ['anytrait']
139 elif isinstance(name, (list, tuple)):
139 elif isinstance(name, (list, tuple)):
140 for n in name:
140 for n in name:
141 assert isinstance(n, str), "names must be strings"
141 assert isinstance(n, str), "names must be strings"
142 return name
142 return name
143
143
144
144
145 class _SimpleTest:
145 class _SimpleTest:
146 def __init__ ( self, value ): self.value = value
146 def __init__ ( self, value ): self.value = value
147 def __call__ ( self, test ):
147 def __call__ ( self, test ):
148 return test == self.value
148 return test == self.value
149 def __repr__(self):
149 def __repr__(self):
150 return "<SimpleTest(%r)" % self.value
150 return "<SimpleTest(%r)" % self.value
151 def __str__(self):
151 def __str__(self):
152 return self.__repr__()
152 return self.__repr__()
153
153
154
154
155 def getmembers(object, predicate=None):
155 def getmembers(object, predicate=None):
156 """A safe version of inspect.getmembers that handles missing attributes.
156 """A safe version of inspect.getmembers that handles missing attributes.
157
157
158 This is useful when there are descriptor based attributes that for
158 This is useful when there are descriptor based attributes that for
159 some reason raise AttributeError even though they exist. This happens
159 some reason raise AttributeError even though they exist. This happens
160 in zope.inteface with the __provides__ attribute.
160 in zope.inteface with the __provides__ attribute.
161 """
161 """
162 results = []
162 results = []
163 for key in dir(object):
163 for key in dir(object):
164 try:
164 try:
165 value = getattr(object, key)
165 value = getattr(object, key)
166 except AttributeError:
166 except AttributeError:
167 pass
167 pass
168 else:
168 else:
169 if not predicate or predicate(value):
169 if not predicate or predicate(value):
170 results.append((key, value))
170 results.append((key, value))
171 results.sort()
171 results.sort()
172 return results
172 return results
173
173
174 @skip_doctest
174 @skip_doctest
175 class link(object):
175 class link(object):
176 """Link traits from different objects together so they remain in sync.
176 """Link traits from different objects together so they remain in sync.
177
177
178 Parameters
178 Parameters
179 ----------
179 ----------
180 obj : pairs of objects/attributes
180 obj : pairs of objects/attributes
181
181
182 Examples
182 Examples
183 --------
183 --------
184
184
185 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
185 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
186 >>> obj1.value = 5 # updates other objects as well
186 >>> obj1.value = 5 # updates other objects as well
187 """
187 """
188 updating = False
188 updating = False
189 def __init__(self, *args):
189 def __init__(self, *args):
190 if len(args) < 2:
190 if len(args) < 2:
191 raise TypeError('At least two traitlets must be provided.')
191 raise TypeError('At least two traitlets must be provided.')
192
192
193 self.objects = {}
193 self.objects = {}
194 initial = getattr(args[0][0], args[0][1])
194 initial = getattr(args[0][0], args[0][1])
195 for obj,attr in args:
195 for obj,attr in args:
196 if getattr(obj, attr) != initial:
196 if getattr(obj, attr) != initial:
197 setattr(obj, attr, initial)
197 setattr(obj, attr, initial)
198
198
199 callback = self._make_closure(obj,attr)
199 callback = self._make_closure(obj,attr)
200 obj.on_trait_change(callback, attr)
200 obj.on_trait_change(callback, attr)
201 self.objects[(obj,attr)] = callback
201 self.objects[(obj,attr)] = callback
202
202
203 @contextlib.contextmanager
203 @contextlib.contextmanager
204 def _busy_updating(self):
204 def _busy_updating(self):
205 self.updating = True
205 self.updating = True
206 try:
206 try:
207 yield
207 yield
208 finally:
208 finally:
209 self.updating = False
209 self.updating = False
210
210
211 def _make_closure(self, sending_obj, sending_attr):
211 def _make_closure(self, sending_obj, sending_attr):
212 def update(name, old, new):
212 def update(name, old, new):
213 self._update(sending_obj, sending_attr, new)
213 self._update(sending_obj, sending_attr, new)
214 return update
214 return update
215
215
216 def _update(self, sending_obj, sending_attr, new):
216 def _update(self, sending_obj, sending_attr, new):
217 if self.updating:
217 if self.updating:
218 return
218 return
219 with self._busy_updating():
219 with self._busy_updating():
220 for obj,attr in self.objects.keys():
220 for obj,attr in self.objects.keys():
221 if obj is not sending_obj or attr != sending_attr:
221 if obj is not sending_obj or attr != sending_attr:
222 setattr(obj, attr, new)
222 setattr(obj, attr, new)
223
223
224 def unlink(self):
224 def unlink(self):
225 for key, callback in self.objects.items():
225 for key, callback in self.objects.items():
226 (obj,attr) = key
226 (obj,attr) = key
227 obj.on_trait_change(callback, attr, remove=True)
227 obj.on_trait_change(callback, attr, remove=True)
228
228
229 #-----------------------------------------------------------------------------
229 #-----------------------------------------------------------------------------
230 # Base TraitType for all traits
230 # Base TraitType for all traits
231 #-----------------------------------------------------------------------------
231 #-----------------------------------------------------------------------------
232
232
233
233
234 class TraitType(object):
234 class TraitType(object):
235 """A base class for all trait descriptors.
235 """A base class for all trait descriptors.
236
236
237 Notes
237 Notes
238 -----
238 -----
239 Our implementation of traits is based on Python's descriptor
239 Our implementation of traits is based on Python's descriptor
240 prototol. This class is the base class for all such descriptors. The
240 prototol. This class is the base class for all such descriptors. The
241 only magic we use is a custom metaclass for the main :class:`HasTraits`
241 only magic we use is a custom metaclass for the main :class:`HasTraits`
242 class that does the following:
242 class that does the following:
243
243
244 1. Sets the :attr:`name` attribute of every :class:`TraitType`
244 1. Sets the :attr:`name` attribute of every :class:`TraitType`
245 instance in the class dict to the name of the attribute.
245 instance in the class dict to the name of the attribute.
246 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
246 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
247 instance in the class dict to the *class* that declared the trait.
247 instance in the class dict to the *class* that declared the trait.
248 This is used by the :class:`This` trait to allow subclasses to
248 This is used by the :class:`This` trait to allow subclasses to
249 accept superclasses for :class:`This` values.
249 accept superclasses for :class:`This` values.
250 """
250 """
251
251
252
252
253 metadata = {}
253 metadata = {}
254 default_value = Undefined
254 default_value = Undefined
255 allow_none = False
255 allow_none = False
256 info_text = 'any value'
256 info_text = 'any value'
257
257
258 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
258 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
259 """Create a TraitType.
259 """Create a TraitType.
260 """
260 """
261 if default_value is not NoDefaultSpecified:
261 if default_value is not NoDefaultSpecified:
262 self.default_value = default_value
262 self.default_value = default_value
263 if allow_none is not None:
263 if allow_none is not None:
264 self.allow_none = allow_none
264 self.allow_none = allow_none
265
265
266 if len(metadata) > 0:
266 if len(metadata) > 0:
267 if len(self.metadata) > 0:
267 if len(self.metadata) > 0:
268 self._metadata = self.metadata.copy()
268 self._metadata = self.metadata.copy()
269 self._metadata.update(metadata)
269 self._metadata.update(metadata)
270 else:
270 else:
271 self._metadata = metadata
271 self._metadata = metadata
272 else:
272 else:
273 self._metadata = self.metadata
273 self._metadata = self.metadata
274
274
275 self.init()
275 self.init()
276
276
277 def init(self):
277 def init(self):
278 pass
278 pass
279
279
280 def get_default_value(self):
280 def get_default_value(self):
281 """Create a new instance of the default value."""
281 """Create a new instance of the default value."""
282 return self.default_value
282 return self.default_value
283
283
284 def instance_init(self, obj):
284 def instance_init(self, obj):
285 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
285 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
286
286
287 Some stages of initialization must be delayed until the parent
287 Some stages of initialization must be delayed until the parent
288 :class:`HasTraits` instance has been created. This method is
288 :class:`HasTraits` instance has been created. This method is
289 called in :meth:`HasTraits.__new__` after the instance has been
289 called in :meth:`HasTraits.__new__` after the instance has been
290 created.
290 created.
291
291
292 This method trigger the creation and validation of default values
292 This method trigger the creation and validation of default values
293 and also things like the resolution of str given class names in
293 and also things like the resolution of str given class names in
294 :class:`Type` and :class`Instance`.
294 :class:`Type` and :class`Instance`.
295
295
296 Parameters
296 Parameters
297 ----------
297 ----------
298 obj : :class:`HasTraits` instance
298 obj : :class:`HasTraits` instance
299 The parent :class:`HasTraits` instance that has just been
299 The parent :class:`HasTraits` instance that has just been
300 created.
300 created.
301 """
301 """
302 self.set_default_value(obj)
302 self.set_default_value(obj)
303
303
304 def set_default_value(self, obj):
304 def set_default_value(self, obj):
305 """Set the default value on a per instance basis.
305 """Set the default value on a per instance basis.
306
306
307 This method is called by :meth:`instance_init` to create and
307 This method is called by :meth:`instance_init` to create and
308 validate the default value. The creation and validation of
308 validate the default value. The creation and validation of
309 default values must be delayed until the parent :class:`HasTraits`
309 default values must be delayed until the parent :class:`HasTraits`
310 class has been instantiated.
310 class has been instantiated.
311 """
311 """
312 # Check for a deferred initializer defined in the same class as the
312 # Check for a deferred initializer defined in the same class as the
313 # trait declaration or above.
313 # trait declaration or above.
314 mro = type(obj).mro()
314 mro = type(obj).mro()
315 meth_name = '_%s_default' % self.name
315 meth_name = '_%s_default' % self.name
316 for cls in mro[:mro.index(self.this_class)+1]:
316 for cls in mro[:mro.index(self.this_class)+1]:
317 if meth_name in cls.__dict__:
317 if meth_name in cls.__dict__:
318 break
318 break
319 else:
319 else:
320 # We didn't find one. Do static initialization.
320 # We didn't find one. Do static initialization.
321 dv = self.get_default_value()
321 dv = self.get_default_value()
322 newdv = self._validate(obj, dv)
322 newdv = self._validate(obj, dv)
323 obj._trait_values[self.name] = newdv
323 obj._trait_values[self.name] = newdv
324 return
324 return
325 # Complete the dynamic initialization.
325 # Complete the dynamic initialization.
326 obj._trait_dyn_inits[self.name] = meth_name
326 obj._trait_dyn_inits[self.name] = meth_name
327
327
328 def __get__(self, obj, cls=None):
328 def __get__(self, obj, cls=None):
329 """Get the value of the trait by self.name for the instance.
329 """Get the value of the trait by self.name for the instance.
330
330
331 Default values are instantiated when :meth:`HasTraits.__new__`
331 Default values are instantiated when :meth:`HasTraits.__new__`
332 is called. Thus by the time this method gets called either the
332 is called. Thus by the time this method gets called either the
333 default value or a user defined value (they called :meth:`__set__`)
333 default value or a user defined value (they called :meth:`__set__`)
334 is in the :class:`HasTraits` instance.
334 is in the :class:`HasTraits` instance.
335 """
335 """
336 if obj is None:
336 if obj is None:
337 return self
337 return self
338 else:
338 else:
339 try:
339 try:
340 value = obj._trait_values[self.name]
340 value = obj._trait_values[self.name]
341 except KeyError:
341 except KeyError:
342 # Check for a dynamic initializer.
342 # Check for a dynamic initializer.
343 if self.name in obj._trait_dyn_inits:
343 if self.name in obj._trait_dyn_inits:
344 method = getattr(obj, obj._trait_dyn_inits[self.name])
344 method = getattr(obj, obj._trait_dyn_inits[self.name])
345 value = method()
345 value = method()
346 # FIXME: Do we really validate here?
346 # FIXME: Do we really validate here?
347 value = self._validate(obj, value)
347 value = self._validate(obj, value)
348 obj._trait_values[self.name] = value
348 obj._trait_values[self.name] = value
349 return value
349 return value
350 else:
350 else:
351 raise TraitError('Unexpected error in TraitType: '
351 raise TraitError('Unexpected error in TraitType: '
352 'both default value and dynamic initializer are '
352 'both default value and dynamic initializer are '
353 'absent.')
353 'absent.')
354 except Exception:
354 except Exception:
355 # HasTraits should call set_default_value to populate
355 # HasTraits should call set_default_value to populate
356 # this. So this should never be reached.
356 # this. So this should never be reached.
357 raise TraitError('Unexpected error in TraitType: '
357 raise TraitError('Unexpected error in TraitType: '
358 'default value not set properly')
358 'default value not set properly')
359 else:
359 else:
360 return value
360 return value
361
361
362 def __set__(self, obj, value):
362 def __set__(self, obj, value):
363 new_value = self._validate(obj, value)
363 new_value = self._validate(obj, value)
364 old_value = self.__get__(obj)
364 old_value = self.__get__(obj)
365 obj._trait_values[self.name] = new_value
365 obj._trait_values[self.name] = new_value
366 try:
366 try:
367 silent = bool(old_value == new_value)
367 silent = bool(old_value == new_value)
368 except:
368 except:
369 # if there is an error in comparing, default to notify
369 # if there is an error in comparing, default to notify
370 silent = False
370 silent = False
371 if silent is not True:
371 if silent is not True:
372 # we explicitly compare silent to True just in case the equality
372 # we explicitly compare silent to True just in case the equality
373 # comparison above returns something other than True/False
373 # comparison above returns something other than True/False
374 obj._notify_trait(self.name, old_value, new_value)
374 obj._notify_trait(self.name, old_value, new_value)
375
375
376 def _validate(self, obj, value):
376 def _validate(self, obj, value):
377 if value is None and self.allow_none:
377 if value is None and self.allow_none:
378 return value
378 return value
379 if hasattr(self, 'validate'):
379 if hasattr(self, 'validate'):
380 return self.validate(obj, value)
380 return self.validate(obj, value)
381 elif hasattr(self, 'is_valid_for'):
381 elif hasattr(self, 'is_valid_for'):
382 valid = self.is_valid_for(value)
382 valid = self.is_valid_for(value)
383 if valid:
383 if valid:
384 return value
384 return value
385 else:
385 else:
386 raise TraitError('invalid value for type: %r' % value)
386 raise TraitError('invalid value for type: %r' % value)
387 elif hasattr(self, 'value_for'):
387 elif hasattr(self, 'value_for'):
388 return self.value_for(value)
388 return self.value_for(value)
389 else:
389 else:
390 return value
390 return value
391
391
392 def info(self):
392 def info(self):
393 return self.info_text
393 return self.info_text
394
394
395 def error(self, obj, value):
395 def error(self, obj, value):
396 if obj is not None:
396 if obj is not None:
397 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
397 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
398 % (self.name, class_of(obj),
398 % (self.name, class_of(obj),
399 self.info(), repr_type(value))
399 self.info(), repr_type(value))
400 else:
400 else:
401 e = "The '%s' trait must be %s, but a value of %r was specified." \
401 e = "The '%s' trait must be %s, but a value of %r was specified." \
402 % (self.name, self.info(), repr_type(value))
402 % (self.name, self.info(), repr_type(value))
403 raise TraitError(e)
403 raise TraitError(e)
404
404
405 def get_metadata(self, key):
405 def get_metadata(self, key):
406 return getattr(self, '_metadata', {}).get(key, None)
406 return getattr(self, '_metadata', {}).get(key, None)
407
407
408 def set_metadata(self, key, value):
408 def set_metadata(self, key, value):
409 getattr(self, '_metadata', {})[key] = value
409 getattr(self, '_metadata', {})[key] = value
410
410
411
411
412 #-----------------------------------------------------------------------------
412 #-----------------------------------------------------------------------------
413 # The HasTraits implementation
413 # The HasTraits implementation
414 #-----------------------------------------------------------------------------
414 #-----------------------------------------------------------------------------
415
415
416
416
417 class MetaHasTraits(type):
417 class MetaHasTraits(type):
418 """A metaclass for HasTraits.
418 """A metaclass for HasTraits.
419
419
420 This metaclass makes sure that any TraitType class attributes are
420 This metaclass makes sure that any TraitType class attributes are
421 instantiated and sets their name attribute.
421 instantiated and sets their name attribute.
422 """
422 """
423
423
424 def __new__(mcls, name, bases, classdict):
424 def __new__(mcls, name, bases, classdict):
425 """Create the HasTraits class.
425 """Create the HasTraits class.
426
426
427 This instantiates all TraitTypes in the class dict and sets their
427 This instantiates all TraitTypes in the class dict and sets their
428 :attr:`name` attribute.
428 :attr:`name` attribute.
429 """
429 """
430 # print "MetaHasTraitlets (mcls, name): ", mcls, name
430 # print "MetaHasTraitlets (mcls, name): ", mcls, name
431 # print "MetaHasTraitlets (bases): ", bases
431 # print "MetaHasTraitlets (bases): ", bases
432 # print "MetaHasTraitlets (classdict): ", classdict
432 # print "MetaHasTraitlets (classdict): ", classdict
433 for k,v in iteritems(classdict):
433 for k,v in iteritems(classdict):
434 if isinstance(v, TraitType):
434 if isinstance(v, TraitType):
435 v.name = k
435 v.name = k
436 elif inspect.isclass(v):
436 elif inspect.isclass(v):
437 if issubclass(v, TraitType):
437 if issubclass(v, TraitType):
438 vinst = v()
438 vinst = v()
439 vinst.name = k
439 vinst.name = k
440 classdict[k] = vinst
440 classdict[k] = vinst
441 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
441 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
442
442
443 def __init__(cls, name, bases, classdict):
443 def __init__(cls, name, bases, classdict):
444 """Finish initializing the HasTraits class.
444 """Finish initializing the HasTraits class.
445
445
446 This sets the :attr:`this_class` attribute of each TraitType in the
446 This sets the :attr:`this_class` attribute of each TraitType in the
447 class dict to the newly created class ``cls``.
447 class dict to the newly created class ``cls``.
448 """
448 """
449 for k, v in iteritems(classdict):
449 for k, v in iteritems(classdict):
450 if isinstance(v, TraitType):
450 if isinstance(v, TraitType):
451 v.this_class = cls
451 v.this_class = cls
452 super(MetaHasTraits, cls).__init__(name, bases, classdict)
452 super(MetaHasTraits, cls).__init__(name, bases, classdict)
453
453
454 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
454 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
455
455
456 def __new__(cls, *args, **kw):
456 def __new__(cls, *args, **kw):
457 # This is needed because object.__new__ only accepts
457 # This is needed because object.__new__ only accepts
458 # the cls argument.
458 # the cls argument.
459 new_meth = super(HasTraits, cls).__new__
459 new_meth = super(HasTraits, cls).__new__
460 if new_meth is object.__new__:
460 if new_meth is object.__new__:
461 inst = new_meth(cls)
461 inst = new_meth(cls)
462 else:
462 else:
463 inst = new_meth(cls, **kw)
463 inst = new_meth(cls, **kw)
464 inst._trait_values = {}
464 inst._trait_values = {}
465 inst._trait_notifiers = {}
465 inst._trait_notifiers = {}
466 inst._trait_dyn_inits = {}
466 inst._trait_dyn_inits = {}
467 # Here we tell all the TraitType instances to set their default
467 # Here we tell all the TraitType instances to set their default
468 # values on the instance.
468 # values on the instance.
469 for key in dir(cls):
469 for key in dir(cls):
470 # Some descriptors raise AttributeError like zope.interface's
470 # Some descriptors raise AttributeError like zope.interface's
471 # __provides__ attributes even though they exist. This causes
471 # __provides__ attributes even though they exist. This causes
472 # AttributeErrors even though they are listed in dir(cls).
472 # AttributeErrors even though they are listed in dir(cls).
473 try:
473 try:
474 value = getattr(cls, key)
474 value = getattr(cls, key)
475 except AttributeError:
475 except AttributeError:
476 pass
476 pass
477 else:
477 else:
478 if isinstance(value, TraitType):
478 if isinstance(value, TraitType):
479 value.instance_init(inst)
479 value.instance_init(inst)
480
480
481 return inst
481 return inst
482
482
483 def __init__(self, *args, **kw):
483 def __init__(self, *args, **kw):
484 # Allow trait values to be set using keyword arguments.
484 # Allow trait values to be set using keyword arguments.
485 # We need to use setattr for this to trigger validation and
485 # We need to use setattr for this to trigger validation and
486 # notifications.
486 # notifications.
487 for key, value in iteritems(kw):
487 for key, value in iteritems(kw):
488 setattr(self, key, value)
488 setattr(self, key, value)
489
489
490 def _notify_trait(self, name, old_value, new_value):
490 def _notify_trait(self, name, old_value, new_value):
491
491
492 # First dynamic ones
492 # First dynamic ones
493 callables = []
493 callables = []
494 callables.extend(self._trait_notifiers.get(name,[]))
494 callables.extend(self._trait_notifiers.get(name,[]))
495 callables.extend(self._trait_notifiers.get('anytrait',[]))
495 callables.extend(self._trait_notifiers.get('anytrait',[]))
496
496
497 # Now static ones
497 # Now static ones
498 try:
498 try:
499 cb = getattr(self, '_%s_changed' % name)
499 cb = getattr(self, '_%s_changed' % name)
500 except:
500 except:
501 pass
501 pass
502 else:
502 else:
503 callables.append(cb)
503 callables.append(cb)
504
504
505 # Call them all now
505 # Call them all now
506 for c in callables:
506 for c in callables:
507 # Traits catches and logs errors here. I allow them to raise
507 # Traits catches and logs errors here. I allow them to raise
508 if callable(c):
508 if callable(c):
509 argspec = inspect.getargspec(c)
509 argspec = inspect.getargspec(c)
510 nargs = len(argspec[0])
510 nargs = len(argspec[0])
511 # Bound methods have an additional 'self' argument
511 # Bound methods have an additional 'self' argument
512 # I don't know how to treat unbound methods, but they
512 # I don't know how to treat unbound methods, but they
513 # can't really be used for callbacks.
513 # can't really be used for callbacks.
514 if isinstance(c, types.MethodType):
514 if isinstance(c, types.MethodType):
515 offset = -1
515 offset = -1
516 else:
516 else:
517 offset = 0
517 offset = 0
518 if nargs + offset == 0:
518 if nargs + offset == 0:
519 c()
519 c()
520 elif nargs + offset == 1:
520 elif nargs + offset == 1:
521 c(name)
521 c(name)
522 elif nargs + offset == 2:
522 elif nargs + offset == 2:
523 c(name, new_value)
523 c(name, new_value)
524 elif nargs + offset == 3:
524 elif nargs + offset == 3:
525 c(name, old_value, new_value)
525 c(name, old_value, new_value)
526 else:
526 else:
527 raise TraitError('a trait changed callback '
527 raise TraitError('a trait changed callback '
528 'must have 0-3 arguments.')
528 'must have 0-3 arguments.')
529 else:
529 else:
530 raise TraitError('a trait changed callback '
530 raise TraitError('a trait changed callback '
531 'must be callable.')
531 'must be callable.')
532
532
533
533
534 def _add_notifiers(self, handler, name):
534 def _add_notifiers(self, handler, name):
535 if name not in self._trait_notifiers:
535 if name not in self._trait_notifiers:
536 nlist = []
536 nlist = []
537 self._trait_notifiers[name] = nlist
537 self._trait_notifiers[name] = nlist
538 else:
538 else:
539 nlist = self._trait_notifiers[name]
539 nlist = self._trait_notifiers[name]
540 if handler not in nlist:
540 if handler not in nlist:
541 nlist.append(handler)
541 nlist.append(handler)
542
542
543 def _remove_notifiers(self, handler, name):
543 def _remove_notifiers(self, handler, name):
544 if name in self._trait_notifiers:
544 if name in self._trait_notifiers:
545 nlist = self._trait_notifiers[name]
545 nlist = self._trait_notifiers[name]
546 try:
546 try:
547 index = nlist.index(handler)
547 index = nlist.index(handler)
548 except ValueError:
548 except ValueError:
549 pass
549 pass
550 else:
550 else:
551 del nlist[index]
551 del nlist[index]
552
552
553 def on_trait_change(self, handler, name=None, remove=False):
553 def on_trait_change(self, handler, name=None, remove=False):
554 """Setup a handler to be called when a trait changes.
554 """Setup a handler to be called when a trait changes.
555
555
556 This is used to setup dynamic notifications of trait changes.
556 This is used to setup dynamic notifications of trait changes.
557
557
558 Static handlers can be created by creating methods on a HasTraits
558 Static handlers can be created by creating methods on a HasTraits
559 subclass with the naming convention '_[traitname]_changed'. Thus,
559 subclass with the naming convention '_[traitname]_changed'. Thus,
560 to create static handler for the trait 'a', create the method
560 to create static handler for the trait 'a', create the method
561 _a_changed(self, name, old, new) (fewer arguments can be used, see
561 _a_changed(self, name, old, new) (fewer arguments can be used, see
562 below).
562 below).
563
563
564 Parameters
564 Parameters
565 ----------
565 ----------
566 handler : callable
566 handler : callable
567 A callable that is called when a trait changes. Its
567 A callable that is called when a trait changes. Its
568 signature can be handler(), handler(name), handler(name, new)
568 signature can be handler(), handler(name), handler(name, new)
569 or handler(name, old, new).
569 or handler(name, old, new).
570 name : list, str, None
570 name : list, str, None
571 If None, the handler will apply to all traits. If a list
571 If None, the handler will apply to all traits. If a list
572 of str, handler will apply to all names in the list. If a
572 of str, handler will apply to all names in the list. If a
573 str, the handler will apply just to that name.
573 str, the handler will apply just to that name.
574 remove : bool
574 remove : bool
575 If False (the default), then install the handler. If True
575 If False (the default), then install the handler. If True
576 then unintall it.
576 then unintall it.
577 """
577 """
578 if remove:
578 if remove:
579 names = parse_notifier_name(name)
579 names = parse_notifier_name(name)
580 for n in names:
580 for n in names:
581 self._remove_notifiers(handler, n)
581 self._remove_notifiers(handler, n)
582 else:
582 else:
583 names = parse_notifier_name(name)
583 names = parse_notifier_name(name)
584 for n in names:
584 for n in names:
585 self._add_notifiers(handler, n)
585 self._add_notifiers(handler, n)
586
586
587 @classmethod
587 @classmethod
588 def class_trait_names(cls, **metadata):
588 def class_trait_names(cls, **metadata):
589 """Get a list of all the names of this class' traits.
589 """Get a list of all the names of this class' traits.
590
590
591 This method is just like the :meth:`trait_names` method,
591 This method is just like the :meth:`trait_names` method,
592 but is unbound.
592 but is unbound.
593 """
593 """
594 return cls.class_traits(**metadata).keys()
594 return cls.class_traits(**metadata).keys()
595
595
596 @classmethod
596 @classmethod
597 def class_traits(cls, **metadata):
597 def class_traits(cls, **metadata):
598 """Get a `dict` of all the traits of this class. The dictionary
598 """Get a `dict` of all the traits of this class. The dictionary
599 is keyed on the name and the values are the TraitType objects.
599 is keyed on the name and the values are the TraitType objects.
600
600
601 This method is just like the :meth:`traits` method, but is unbound.
601 This method is just like the :meth:`traits` method, but is unbound.
602
602
603 The TraitTypes returned don't know anything about the values
603 The TraitTypes returned don't know anything about the values
604 that the various HasTrait's instances are holding.
604 that the various HasTrait's instances are holding.
605
605
606 The metadata kwargs allow functions to be passed in which
606 The metadata kwargs allow functions to be passed in which
607 filter traits based on metadata values. The functions should
607 filter traits based on metadata values. The functions should
608 take a single value as an argument and return a boolean. If
608 take a single value as an argument and return a boolean. If
609 any function returns False, then the trait is not included in
609 any function returns False, then the trait is not included in
610 the output. This does not allow for any simple way of
610 the output. This does not allow for any simple way of
611 testing that a metadata name exists and has any
611 testing that a metadata name exists and has any
612 value because get_metadata returns None if a metadata key
612 value because get_metadata returns None if a metadata key
613 doesn't exist.
613 doesn't exist.
614 """
614 """
615 traits = dict([memb for memb in getmembers(cls) if
615 traits = dict([memb for memb in getmembers(cls) if
616 isinstance(memb[1], TraitType)])
616 isinstance(memb[1], TraitType)])
617
617
618 if len(metadata) == 0:
618 if len(metadata) == 0:
619 return traits
619 return traits
620
620
621 for meta_name, meta_eval in metadata.items():
621 for meta_name, meta_eval in metadata.items():
622 if type(meta_eval) is not FunctionType:
622 if type(meta_eval) is not FunctionType:
623 metadata[meta_name] = _SimpleTest(meta_eval)
623 metadata[meta_name] = _SimpleTest(meta_eval)
624
624
625 result = {}
625 result = {}
626 for name, trait in traits.items():
626 for name, trait in traits.items():
627 for meta_name, meta_eval in metadata.items():
627 for meta_name, meta_eval in metadata.items():
628 if not meta_eval(trait.get_metadata(meta_name)):
628 if not meta_eval(trait.get_metadata(meta_name)):
629 break
629 break
630 else:
630 else:
631 result[name] = trait
631 result[name] = trait
632
632
633 return result
633 return result
634
634
635 def trait_names(self, **metadata):
635 def trait_names(self, **metadata):
636 """Get a list of all the names of this class' traits."""
636 """Get a list of all the names of this class' traits."""
637 return self.traits(**metadata).keys()
637 return self.traits(**metadata).keys()
638
638
639 def traits(self, **metadata):
639 def traits(self, **metadata):
640 """Get a `dict` of all the traits of this class. The dictionary
640 """Get a `dict` of all the traits of this class. The dictionary
641 is keyed on the name and the values are the TraitType objects.
641 is keyed on the name and the values are the TraitType objects.
642
642
643 The TraitTypes returned don't know anything about the values
643 The TraitTypes returned don't know anything about the values
644 that the various HasTrait's instances are holding.
644 that the various HasTrait's instances are holding.
645
645
646 The metadata kwargs allow functions to be passed in which
646 The metadata kwargs allow functions to be passed in which
647 filter traits based on metadata values. The functions should
647 filter traits based on metadata values. The functions should
648 take a single value as an argument and return a boolean. If
648 take a single value as an argument and return a boolean. If
649 any function returns False, then the trait is not included in
649 any function returns False, then the trait is not included in
650 the output. This does not allow for any simple way of
650 the output. This does not allow for any simple way of
651 testing that a metadata name exists and has any
651 testing that a metadata name exists and has any
652 value because get_metadata returns None if a metadata key
652 value because get_metadata returns None if a metadata key
653 doesn't exist.
653 doesn't exist.
654 """
654 """
655 traits = dict([memb for memb in getmembers(self.__class__) if
655 traits = dict([memb for memb in getmembers(self.__class__) if
656 isinstance(memb[1], TraitType)])
656 isinstance(memb[1], TraitType)])
657
657
658 if len(metadata) == 0:
658 if len(metadata) == 0:
659 return traits
659 return traits
660
660
661 for meta_name, meta_eval in metadata.items():
661 for meta_name, meta_eval in metadata.items():
662 if type(meta_eval) is not FunctionType:
662 if type(meta_eval) is not FunctionType:
663 metadata[meta_name] = _SimpleTest(meta_eval)
663 metadata[meta_name] = _SimpleTest(meta_eval)
664
664
665 result = {}
665 result = {}
666 for name, trait in traits.items():
666 for name, trait in traits.items():
667 for meta_name, meta_eval in metadata.items():
667 for meta_name, meta_eval in metadata.items():
668 if not meta_eval(trait.get_metadata(meta_name)):
668 if not meta_eval(trait.get_metadata(meta_name)):
669 break
669 break
670 else:
670 else:
671 result[name] = trait
671 result[name] = trait
672
672
673 return result
673 return result
674
674
675 def trait_metadata(self, traitname, key):
675 def trait_metadata(self, traitname, key):
676 """Get metadata values for trait by key."""
676 """Get metadata values for trait by key."""
677 try:
677 try:
678 trait = getattr(self.__class__, traitname)
678 trait = getattr(self.__class__, traitname)
679 except AttributeError:
679 except AttributeError:
680 raise TraitError("Class %s does not have a trait named %s" %
680 raise TraitError("Class %s does not have a trait named %s" %
681 (self.__class__.__name__, traitname))
681 (self.__class__.__name__, traitname))
682 else:
682 else:
683 return trait.get_metadata(key)
683 return trait.get_metadata(key)
684
684
685 #-----------------------------------------------------------------------------
685 #-----------------------------------------------------------------------------
686 # Actual TraitTypes implementations/subclasses
686 # Actual TraitTypes implementations/subclasses
687 #-----------------------------------------------------------------------------
687 #-----------------------------------------------------------------------------
688
688
689 #-----------------------------------------------------------------------------
689 #-----------------------------------------------------------------------------
690 # TraitTypes subclasses for handling classes and instances of classes
690 # TraitTypes subclasses for handling classes and instances of classes
691 #-----------------------------------------------------------------------------
691 #-----------------------------------------------------------------------------
692
692
693
693
694 class ClassBasedTraitType(TraitType):
694 class ClassBasedTraitType(TraitType):
695 """A trait with error reporting for Type, Instance and This."""
695 """A trait with error reporting for Type, Instance and This."""
696
696
697 def error(self, obj, value):
697 def error(self, obj, value):
698 kind = type(value)
698 kind = type(value)
699 if (not py3compat.PY3) and kind is InstanceType:
699 if (not py3compat.PY3) and kind is InstanceType:
700 msg = 'class %s' % value.__class__.__name__
700 msg = 'class %s' % value.__class__.__name__
701 else:
701 else:
702 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
702 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
703
703
704 if obj is not None:
704 if obj is not None:
705 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
705 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
706 % (self.name, class_of(obj),
706 % (self.name, class_of(obj),
707 self.info(), msg)
707 self.info(), msg)
708 else:
708 else:
709 e = "The '%s' trait must be %s, but a value of %r was specified." \
709 e = "The '%s' trait must be %s, but a value of %r was specified." \
710 % (self.name, self.info(), msg)
710 % (self.name, self.info(), msg)
711
711
712 raise TraitError(e)
712 raise TraitError(e)
713
713
714
714
715 class Type(ClassBasedTraitType):
715 class Type(ClassBasedTraitType):
716 """A trait whose value must be a subclass of a specified class."""
716 """A trait whose value must be a subclass of a specified class."""
717
717
718 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
718 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
719 """Construct a Type trait
719 """Construct a Type trait
720
720
721 A Type trait specifies that its values must be subclasses of
721 A Type trait specifies that its values must be subclasses of
722 a particular class.
722 a particular class.
723
723
724 If only ``default_value`` is given, it is used for the ``klass`` as
724 If only ``default_value`` is given, it is used for the ``klass`` as
725 well.
725 well.
726
726
727 Parameters
727 Parameters
728 ----------
728 ----------
729 default_value : class, str or None
729 default_value : class, str or None
730 The default value must be a subclass of klass. If an str,
730 The default value must be a subclass of klass. If an str,
731 the str must be a fully specified class name, like 'foo.bar.Bah'.
731 the str must be a fully specified class name, like 'foo.bar.Bah'.
732 The string is resolved into real class, when the parent
732 The string is resolved into real class, when the parent
733 :class:`HasTraits` class is instantiated.
733 :class:`HasTraits` class is instantiated.
734 klass : class, str, None
734 klass : class, str, None
735 Values of this trait must be a subclass of klass. The klass
735 Values of this trait must be a subclass of klass. The klass
736 may be specified in a string like: 'foo.bar.MyClass'.
736 may be specified in a string like: 'foo.bar.MyClass'.
737 The string is resolved into real class, when the parent
737 The string is resolved into real class, when the parent
738 :class:`HasTraits` class is instantiated.
738 :class:`HasTraits` class is instantiated.
739 allow_none : boolean
739 allow_none : boolean
740 Indicates whether None is allowed as an assignable value. Even if
740 Indicates whether None is allowed as an assignable value. Even if
741 ``False``, the default value may be ``None``.
741 ``False``, the default value may be ``None``.
742 """
742 """
743 if default_value is None:
743 if default_value is None:
744 if klass is None:
744 if klass is None:
745 klass = object
745 klass = object
746 elif klass is None:
746 elif klass is None:
747 klass = default_value
747 klass = default_value
748
748
749 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
749 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
750 raise TraitError("A Type trait must specify a class.")
750 raise TraitError("A Type trait must specify a class.")
751
751
752 self.klass = klass
752 self.klass = klass
753
753
754 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
754 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
755
755
756 def validate(self, obj, value):
756 def validate(self, obj, value):
757 """Validates that the value is a valid object instance."""
757 """Validates that the value is a valid object instance."""
758 if isinstance(value, py3compat.string_types):
758 if isinstance(value, py3compat.string_types):
759 try:
759 try:
760 value = import_item(value)
760 value = import_item(value)
761 except ImportError:
761 except ImportError:
762 raise TraitError("The '%s' trait of %s instance must be a type, but "
762 raise TraitError("The '%s' trait of %s instance must be a type, but "
763 "%r could not be imported" % (self.name, obj, value))
763 "%r could not be imported" % (self.name, obj, value))
764 try:
764 try:
765 if issubclass(value, self.klass):
765 if issubclass(value, self.klass):
766 return value
766 return value
767 except:
767 except:
768 pass
768 pass
769
769
770 self.error(obj, value)
770 self.error(obj, value)
771
771
772 def info(self):
772 def info(self):
773 """ Returns a description of the trait."""
773 """ Returns a description of the trait."""
774 if isinstance(self.klass, py3compat.string_types):
774 if isinstance(self.klass, py3compat.string_types):
775 klass = self.klass
775 klass = self.klass
776 else:
776 else:
777 klass = self.klass.__name__
777 klass = self.klass.__name__
778 result = 'a subclass of ' + klass
778 result = 'a subclass of ' + klass
779 if self.allow_none:
779 if self.allow_none:
780 return result + ' or None'
780 return result + ' or None'
781 return result
781 return result
782
782
783 def instance_init(self, obj):
783 def instance_init(self, obj):
784 self._resolve_classes()
784 self._resolve_classes()
785 super(Type, self).instance_init(obj)
785 super(Type, self).instance_init(obj)
786
786
787 def _resolve_classes(self):
787 def _resolve_classes(self):
788 if isinstance(self.klass, py3compat.string_types):
788 if isinstance(self.klass, py3compat.string_types):
789 self.klass = import_item(self.klass)
789 self.klass = import_item(self.klass)
790 if isinstance(self.default_value, py3compat.string_types):
790 if isinstance(self.default_value, py3compat.string_types):
791 self.default_value = import_item(self.default_value)
791 self.default_value = import_item(self.default_value)
792
792
793 def get_default_value(self):
793 def get_default_value(self):
794 return self.default_value
794 return self.default_value
795
795
796
796
797 class DefaultValueGenerator(object):
797 class DefaultValueGenerator(object):
798 """A class for generating new default value instances."""
798 """A class for generating new default value instances."""
799
799
800 def __init__(self, *args, **kw):
800 def __init__(self, *args, **kw):
801 self.args = args
801 self.args = args
802 self.kw = kw
802 self.kw = kw
803
803
804 def generate(self, klass):
804 def generate(self, klass):
805 return klass(*self.args, **self.kw)
805 return klass(*self.args, **self.kw)
806
806
807
807
808 class Instance(ClassBasedTraitType):
808 class Instance(ClassBasedTraitType):
809 """A trait whose value must be an instance of a specified class.
809 """A trait whose value must be an instance of a specified class.
810
810
811 The value can also be an instance of a subclass of the specified class.
811 The value can also be an instance of a subclass of the specified class.
812
813 Subclasses can declare default classes by overriding the klass attribute
812 """
814 """
813
815
816 klass = None
817
814 def __init__(self, klass=None, args=None, kw=None,
818 def __init__(self, klass=None, args=None, kw=None,
815 allow_none=True, **metadata ):
819 allow_none=True, **metadata ):
816 """Construct an Instance trait.
820 """Construct an Instance trait.
817
821
818 This trait allows values that are instances of a particular
822 This trait allows values that are instances of a particular
819 class or its sublclasses. Our implementation is quite different
823 class or its sublclasses. Our implementation is quite different
820 from that of enthough.traits as we don't allow instances to be used
824 from that of enthough.traits as we don't allow instances to be used
821 for klass and we handle the ``args`` and ``kw`` arguments differently.
825 for klass and we handle the ``args`` and ``kw`` arguments differently.
822
826
823 Parameters
827 Parameters
824 ----------
828 ----------
825 klass : class, str
829 klass : class, str
826 The class that forms the basis for the trait. Class names
830 The class that forms the basis for the trait. Class names
827 can also be specified as strings, like 'foo.bar.Bar'.
831 can also be specified as strings, like 'foo.bar.Bar'.
828 args : tuple
832 args : tuple
829 Positional arguments for generating the default value.
833 Positional arguments for generating the default value.
830 kw : dict
834 kw : dict
831 Keyword arguments for generating the default value.
835 Keyword arguments for generating the default value.
832 allow_none : bool
836 allow_none : bool
833 Indicates whether None is allowed as a value.
837 Indicates whether None is allowed as a value.
834
838
835 Notes
839 Notes
836 -----
840 -----
837 If both ``args`` and ``kw`` are None, then the default value is None.
841 If both ``args`` and ``kw`` are None, then the default value is None.
838 If ``args`` is a tuple and ``kw`` is a dict, then the default is
842 If ``args`` is a tuple and ``kw`` is a dict, then the default is
839 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
843 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
840 not (but not both), None is replace by ``()`` or ``{}``.
844 None, the None is replaced by ``()`` or ``{}``, respectively.
841 """
845 """
842
846 if klass is None:
843 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types))):
847 klass = self.klass
844 raise TraitError('The klass argument must be a class'
848
845 ' you gave: %r' % klass)
849 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
846 self.klass = klass
850 self.klass = klass
851 else:
852 raise TraitError('The klass attribute must be a class'
853 ' not: %r' % klass)
847
854
848 # self.klass is a class, so handle default_value
855 # self.klass is a class, so handle default_value
849 if args is None and kw is None:
856 if args is None and kw is None:
850 default_value = None
857 default_value = None
851 else:
858 else:
852 if args is None:
859 if args is None:
853 # kw is not None
860 # kw is not None
854 args = ()
861 args = ()
855 elif kw is None:
862 elif kw is None:
856 # args is not None
863 # args is not None
857 kw = {}
864 kw = {}
858
865
859 if not isinstance(kw, dict):
866 if not isinstance(kw, dict):
860 raise TraitError("The 'kw' argument must be a dict or None.")
867 raise TraitError("The 'kw' argument must be a dict or None.")
861 if not isinstance(args, tuple):
868 if not isinstance(args, tuple):
862 raise TraitError("The 'args' argument must be a tuple or None.")
869 raise TraitError("The 'args' argument must be a tuple or None.")
863
870
864 default_value = DefaultValueGenerator(*args, **kw)
871 default_value = DefaultValueGenerator(*args, **kw)
865
872
866 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
873 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
867
874
868 def validate(self, obj, value):
875 def validate(self, obj, value):
869 if isinstance(value, self.klass):
876 if isinstance(value, self.klass):
870 return value
877 return value
871 else:
878 else:
872 self.error(obj, value)
879 self.error(obj, value)
873
880
874 def info(self):
881 def info(self):
875 if isinstance(self.klass, py3compat.string_types):
882 if isinstance(self.klass, py3compat.string_types):
876 klass = self.klass
883 klass = self.klass
877 else:
884 else:
878 klass = self.klass.__name__
885 klass = self.klass.__name__
879 result = class_of(klass)
886 result = class_of(klass)
880 if self.allow_none:
887 if self.allow_none:
881 return result + ' or None'
888 return result + ' or None'
882
889
883 return result
890 return result
884
891
885 def instance_init(self, obj):
892 def instance_init(self, obj):
886 self._resolve_classes()
893 self._resolve_classes()
887 super(Instance, self).instance_init(obj)
894 super(Instance, self).instance_init(obj)
888
895
889 def _resolve_classes(self):
896 def _resolve_classes(self):
890 if isinstance(self.klass, py3compat.string_types):
897 if isinstance(self.klass, py3compat.string_types):
891 self.klass = import_item(self.klass)
898 self.klass = import_item(self.klass)
892
899
893 def get_default_value(self):
900 def get_default_value(self):
894 """Instantiate a default value instance.
901 """Instantiate a default value instance.
895
902
896 This is called when the containing HasTraits classes'
903 This is called when the containing HasTraits classes'
897 :meth:`__new__` method is called to ensure that a unique instance
904 :meth:`__new__` method is called to ensure that a unique instance
898 is created for each HasTraits instance.
905 is created for each HasTraits instance.
899 """
906 """
900 dv = self.default_value
907 dv = self.default_value
901 if isinstance(dv, DefaultValueGenerator):
908 if isinstance(dv, DefaultValueGenerator):
902 return dv.generate(self.klass)
909 return dv.generate(self.klass)
903 else:
910 else:
904 return dv
911 return dv
905
912
906
913
907 class This(ClassBasedTraitType):
914 class This(ClassBasedTraitType):
908 """A trait for instances of the class containing this trait.
915 """A trait for instances of the class containing this trait.
909
916
910 Because how how and when class bodies are executed, the ``This``
917 Because how how and when class bodies are executed, the ``This``
911 trait can only have a default value of None. This, and because we
918 trait can only have a default value of None. This, and because we
912 always validate default values, ``allow_none`` is *always* true.
919 always validate default values, ``allow_none`` is *always* true.
913 """
920 """
914
921
915 info_text = 'an instance of the same type as the receiver or None'
922 info_text = 'an instance of the same type as the receiver or None'
916
923
917 def __init__(self, **metadata):
924 def __init__(self, **metadata):
918 super(This, self).__init__(None, **metadata)
925 super(This, self).__init__(None, **metadata)
919
926
920 def validate(self, obj, value):
927 def validate(self, obj, value):
921 # What if value is a superclass of obj.__class__? This is
928 # What if value is a superclass of obj.__class__? This is
922 # complicated if it was the superclass that defined the This
929 # complicated if it was the superclass that defined the This
923 # trait.
930 # trait.
924 if isinstance(value, self.this_class) or (value is None):
931 if isinstance(value, self.this_class) or (value is None):
925 return value
932 return value
926 else:
933 else:
927 self.error(obj, value)
934 self.error(obj, value)
928
935
929
936
930 #-----------------------------------------------------------------------------
937 #-----------------------------------------------------------------------------
931 # Basic TraitTypes implementations/subclasses
938 # Basic TraitTypes implementations/subclasses
932 #-----------------------------------------------------------------------------
939 #-----------------------------------------------------------------------------
933
940
934
941
935 class Any(TraitType):
942 class Any(TraitType):
936 default_value = None
943 default_value = None
937 info_text = 'any value'
944 info_text = 'any value'
938
945
939
946
940 class Int(TraitType):
947 class Int(TraitType):
941 """An int trait."""
948 """An int trait."""
942
949
943 default_value = 0
950 default_value = 0
944 info_text = 'an int'
951 info_text = 'an int'
945
952
946 def validate(self, obj, value):
953 def validate(self, obj, value):
947 if isinstance(value, int):
954 if isinstance(value, int):
948 return value
955 return value
949 self.error(obj, value)
956 self.error(obj, value)
950
957
951 class CInt(Int):
958 class CInt(Int):
952 """A casting version of the int trait."""
959 """A casting version of the int trait."""
953
960
954 def validate(self, obj, value):
961 def validate(self, obj, value):
955 try:
962 try:
956 return int(value)
963 return int(value)
957 except:
964 except:
958 self.error(obj, value)
965 self.error(obj, value)
959
966
960 if py3compat.PY3:
967 if py3compat.PY3:
961 Long, CLong = Int, CInt
968 Long, CLong = Int, CInt
962 Integer = Int
969 Integer = Int
963 else:
970 else:
964 class Long(TraitType):
971 class Long(TraitType):
965 """A long integer trait."""
972 """A long integer trait."""
966
973
967 default_value = 0
974 default_value = 0
968 info_text = 'a long'
975 info_text = 'a long'
969
976
970 def validate(self, obj, value):
977 def validate(self, obj, value):
971 if isinstance(value, long):
978 if isinstance(value, long):
972 return value
979 return value
973 if isinstance(value, int):
980 if isinstance(value, int):
974 return long(value)
981 return long(value)
975 self.error(obj, value)
982 self.error(obj, value)
976
983
977
984
978 class CLong(Long):
985 class CLong(Long):
979 """A casting version of the long integer trait."""
986 """A casting version of the long integer trait."""
980
987
981 def validate(self, obj, value):
988 def validate(self, obj, value):
982 try:
989 try:
983 return long(value)
990 return long(value)
984 except:
991 except:
985 self.error(obj, value)
992 self.error(obj, value)
986
993
987 class Integer(TraitType):
994 class Integer(TraitType):
988 """An integer trait.
995 """An integer trait.
989
996
990 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
997 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
991
998
992 default_value = 0
999 default_value = 0
993 info_text = 'an integer'
1000 info_text = 'an integer'
994
1001
995 def validate(self, obj, value):
1002 def validate(self, obj, value):
996 if isinstance(value, int):
1003 if isinstance(value, int):
997 return value
1004 return value
998 if isinstance(value, long):
1005 if isinstance(value, long):
999 # downcast longs that fit in int:
1006 # downcast longs that fit in int:
1000 # note that int(n > sys.maxint) returns a long, so
1007 # note that int(n > sys.maxint) returns a long, so
1001 # we don't need a condition on this cast
1008 # we don't need a condition on this cast
1002 return int(value)
1009 return int(value)
1003 if sys.platform == "cli":
1010 if sys.platform == "cli":
1004 from System import Int64
1011 from System import Int64
1005 if isinstance(value, Int64):
1012 if isinstance(value, Int64):
1006 return int(value)
1013 return int(value)
1007 self.error(obj, value)
1014 self.error(obj, value)
1008
1015
1009
1016
1010 class Float(TraitType):
1017 class Float(TraitType):
1011 """A float trait."""
1018 """A float trait."""
1012
1019
1013 default_value = 0.0
1020 default_value = 0.0
1014 info_text = 'a float'
1021 info_text = 'a float'
1015
1022
1016 def validate(self, obj, value):
1023 def validate(self, obj, value):
1017 if isinstance(value, float):
1024 if isinstance(value, float):
1018 return value
1025 return value
1019 if isinstance(value, int):
1026 if isinstance(value, int):
1020 return float(value)
1027 return float(value)
1021 self.error(obj, value)
1028 self.error(obj, value)
1022
1029
1023
1030
1024 class CFloat(Float):
1031 class CFloat(Float):
1025 """A casting version of the float trait."""
1032 """A casting version of the float trait."""
1026
1033
1027 def validate(self, obj, value):
1034 def validate(self, obj, value):
1028 try:
1035 try:
1029 return float(value)
1036 return float(value)
1030 except:
1037 except:
1031 self.error(obj, value)
1038 self.error(obj, value)
1032
1039
1033 class Complex(TraitType):
1040 class Complex(TraitType):
1034 """A trait for complex numbers."""
1041 """A trait for complex numbers."""
1035
1042
1036 default_value = 0.0 + 0.0j
1043 default_value = 0.0 + 0.0j
1037 info_text = 'a complex number'
1044 info_text = 'a complex number'
1038
1045
1039 def validate(self, obj, value):
1046 def validate(self, obj, value):
1040 if isinstance(value, complex):
1047 if isinstance(value, complex):
1041 return value
1048 return value
1042 if isinstance(value, (float, int)):
1049 if isinstance(value, (float, int)):
1043 return complex(value)
1050 return complex(value)
1044 self.error(obj, value)
1051 self.error(obj, value)
1045
1052
1046
1053
1047 class CComplex(Complex):
1054 class CComplex(Complex):
1048 """A casting version of the complex number trait."""
1055 """A casting version of the complex number trait."""
1049
1056
1050 def validate (self, obj, value):
1057 def validate (self, obj, value):
1051 try:
1058 try:
1052 return complex(value)
1059 return complex(value)
1053 except:
1060 except:
1054 self.error(obj, value)
1061 self.error(obj, value)
1055
1062
1056 # We should always be explicit about whether we're using bytes or unicode, both
1063 # We should always be explicit about whether we're using bytes or unicode, both
1057 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1064 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1058 # we don't have a Str type.
1065 # we don't have a Str type.
1059 class Bytes(TraitType):
1066 class Bytes(TraitType):
1060 """A trait for byte strings."""
1067 """A trait for byte strings."""
1061
1068
1062 default_value = b''
1069 default_value = b''
1063 info_text = 'a bytes object'
1070 info_text = 'a bytes object'
1064
1071
1065 def validate(self, obj, value):
1072 def validate(self, obj, value):
1066 if isinstance(value, bytes):
1073 if isinstance(value, bytes):
1067 return value
1074 return value
1068 self.error(obj, value)
1075 self.error(obj, value)
1069
1076
1070
1077
1071 class CBytes(Bytes):
1078 class CBytes(Bytes):
1072 """A casting version of the byte string trait."""
1079 """A casting version of the byte string trait."""
1073
1080
1074 def validate(self, obj, value):
1081 def validate(self, obj, value):
1075 try:
1082 try:
1076 return bytes(value)
1083 return bytes(value)
1077 except:
1084 except:
1078 self.error(obj, value)
1085 self.error(obj, value)
1079
1086
1080
1087
1081 class Unicode(TraitType):
1088 class Unicode(TraitType):
1082 """A trait for unicode strings."""
1089 """A trait for unicode strings."""
1083
1090
1084 default_value = u''
1091 default_value = u''
1085 info_text = 'a unicode string'
1092 info_text = 'a unicode string'
1086
1093
1087 def validate(self, obj, value):
1094 def validate(self, obj, value):
1088 if isinstance(value, py3compat.unicode_type):
1095 if isinstance(value, py3compat.unicode_type):
1089 return value
1096 return value
1090 if isinstance(value, bytes):
1097 if isinstance(value, bytes):
1091 try:
1098 try:
1092 return value.decode('ascii', 'strict')
1099 return value.decode('ascii', 'strict')
1093 except UnicodeDecodeError:
1100 except UnicodeDecodeError:
1094 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1101 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1095 raise TraitError(msg.format(value, self.name, class_of(obj)))
1102 raise TraitError(msg.format(value, self.name, class_of(obj)))
1096 self.error(obj, value)
1103 self.error(obj, value)
1097
1104
1098
1105
1099 class CUnicode(Unicode):
1106 class CUnicode(Unicode):
1100 """A casting version of the unicode trait."""
1107 """A casting version of the unicode trait."""
1101
1108
1102 def validate(self, obj, value):
1109 def validate(self, obj, value):
1103 try:
1110 try:
1104 return py3compat.unicode_type(value)
1111 return py3compat.unicode_type(value)
1105 except:
1112 except:
1106 self.error(obj, value)
1113 self.error(obj, value)
1107
1114
1108
1115
1109 class ObjectName(TraitType):
1116 class ObjectName(TraitType):
1110 """A string holding a valid object name in this version of Python.
1117 """A string holding a valid object name in this version of Python.
1111
1118
1112 This does not check that the name exists in any scope."""
1119 This does not check that the name exists in any scope."""
1113 info_text = "a valid object identifier in Python"
1120 info_text = "a valid object identifier in Python"
1114
1121
1115 if py3compat.PY3:
1122 if py3compat.PY3:
1116 # Python 3:
1123 # Python 3:
1117 coerce_str = staticmethod(lambda _,s: s)
1124 coerce_str = staticmethod(lambda _,s: s)
1118
1125
1119 else:
1126 else:
1120 # Python 2:
1127 # Python 2:
1121 def coerce_str(self, obj, value):
1128 def coerce_str(self, obj, value):
1122 "In Python 2, coerce ascii-only unicode to str"
1129 "In Python 2, coerce ascii-only unicode to str"
1123 if isinstance(value, unicode):
1130 if isinstance(value, unicode):
1124 try:
1131 try:
1125 return str(value)
1132 return str(value)
1126 except UnicodeEncodeError:
1133 except UnicodeEncodeError:
1127 self.error(obj, value)
1134 self.error(obj, value)
1128 return value
1135 return value
1129
1136
1130 def validate(self, obj, value):
1137 def validate(self, obj, value):
1131 value = self.coerce_str(obj, value)
1138 value = self.coerce_str(obj, value)
1132
1139
1133 if isinstance(value, str) and py3compat.isidentifier(value):
1140 if isinstance(value, str) and py3compat.isidentifier(value):
1134 return value
1141 return value
1135 self.error(obj, value)
1142 self.error(obj, value)
1136
1143
1137 class DottedObjectName(ObjectName):
1144 class DottedObjectName(ObjectName):
1138 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1145 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1139 def validate(self, obj, value):
1146 def validate(self, obj, value):
1140 value = self.coerce_str(obj, value)
1147 value = self.coerce_str(obj, value)
1141
1148
1142 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1149 if isinstance(value, str) and py3compat.isidentifier(value, dotted=True):
1143 return value
1150 return value
1144 self.error(obj, value)
1151 self.error(obj, value)
1145
1152
1146
1153
1147 class Bool(TraitType):
1154 class Bool(TraitType):
1148 """A boolean (True, False) trait."""
1155 """A boolean (True, False) trait."""
1149
1156
1150 default_value = False
1157 default_value = False
1151 info_text = 'a boolean'
1158 info_text = 'a boolean'
1152
1159
1153 def validate(self, obj, value):
1160 def validate(self, obj, value):
1154 if isinstance(value, bool):
1161 if isinstance(value, bool):
1155 return value
1162 return value
1156 self.error(obj, value)
1163 self.error(obj, value)
1157
1164
1158
1165
1159 class CBool(Bool):
1166 class CBool(Bool):
1160 """A casting version of the boolean trait."""
1167 """A casting version of the boolean trait."""
1161
1168
1162 def validate(self, obj, value):
1169 def validate(self, obj, value):
1163 try:
1170 try:
1164 return bool(value)
1171 return bool(value)
1165 except:
1172 except:
1166 self.error(obj, value)
1173 self.error(obj, value)
1167
1174
1168
1175
1169 class Enum(TraitType):
1176 class Enum(TraitType):
1170 """An enum that whose value must be in a given sequence."""
1177 """An enum that whose value must be in a given sequence."""
1171
1178
1172 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1179 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1173 self.values = values
1180 self.values = values
1174 super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata)
1181 super(Enum, self).__init__(default_value, allow_none=allow_none, **metadata)
1175
1182
1176 def validate(self, obj, value):
1183 def validate(self, obj, value):
1177 if value in self.values:
1184 if value in self.values:
1178 return value
1185 return value
1179 self.error(obj, value)
1186 self.error(obj, value)
1180
1187
1181 def info(self):
1188 def info(self):
1182 """ Returns a description of the trait."""
1189 """ Returns a description of the trait."""
1183 result = 'any of ' + repr(self.values)
1190 result = 'any of ' + repr(self.values)
1184 if self.allow_none:
1191 if self.allow_none:
1185 return result + ' or None'
1192 return result + ' or None'
1186 return result
1193 return result
1187
1194
1188 class CaselessStrEnum(Enum):
1195 class CaselessStrEnum(Enum):
1189 """An enum of strings that are caseless in validate."""
1196 """An enum of strings that are caseless in validate."""
1190
1197
1191 def validate(self, obj, value):
1198 def validate(self, obj, value):
1192 if not isinstance(value, py3compat.string_types):
1199 if not isinstance(value, py3compat.string_types):
1193 self.error(obj, value)
1200 self.error(obj, value)
1194
1201
1195 for v in self.values:
1202 for v in self.values:
1196 if v.lower() == value.lower():
1203 if v.lower() == value.lower():
1197 return v
1204 return v
1198 self.error(obj, value)
1205 self.error(obj, value)
1199
1206
1200 class Container(Instance):
1207 class Container(Instance):
1201 """An instance of a container (list, set, etc.)
1208 """An instance of a container (list, set, etc.)
1202
1209
1203 To be subclassed by overriding klass.
1210 To be subclassed by overriding klass.
1204 """
1211 """
1205 klass = None
1212 klass = None
1206 _cast_types = ()
1213 _cast_types = ()
1207 _valid_defaults = SequenceTypes
1214 _valid_defaults = SequenceTypes
1208 _trait = None
1215 _trait = None
1209
1216
1210 def __init__(self, trait=None, default_value=None, allow_none=True,
1217 def __init__(self, trait=None, default_value=None, allow_none=True,
1211 **metadata):
1218 **metadata):
1212 """Create a container trait type from a list, set, or tuple.
1219 """Create a container trait type from a list, set, or tuple.
1213
1220
1214 The default value is created by doing ``List(default_value)``,
1221 The default value is created by doing ``List(default_value)``,
1215 which creates a copy of the ``default_value``.
1222 which creates a copy of the ``default_value``.
1216
1223
1217 ``trait`` can be specified, which restricts the type of elements
1224 ``trait`` can be specified, which restricts the type of elements
1218 in the container to that TraitType.
1225 in the container to that TraitType.
1219
1226
1220 If only one arg is given and it is not a Trait, it is taken as
1227 If only one arg is given and it is not a Trait, it is taken as
1221 ``default_value``:
1228 ``default_value``:
1222
1229
1223 ``c = List([1,2,3])``
1230 ``c = List([1,2,3])``
1224
1231
1225 Parameters
1232 Parameters
1226 ----------
1233 ----------
1227
1234
1228 trait : TraitType [ optional ]
1235 trait : TraitType [ optional ]
1229 the type for restricting the contents of the Container. If unspecified,
1236 the type for restricting the contents of the Container. If unspecified,
1230 types are not checked.
1237 types are not checked.
1231
1238
1232 default_value : SequenceType [ optional ]
1239 default_value : SequenceType [ optional ]
1233 The default value for the Trait. Must be list/tuple/set, and
1240 The default value for the Trait. Must be list/tuple/set, and
1234 will be cast to the container type.
1241 will be cast to the container type.
1235
1242
1236 allow_none : Bool [ default True ]
1243 allow_none : Bool [ default True ]
1237 Whether to allow the value to be None
1244 Whether to allow the value to be None
1238
1245
1239 **metadata : any
1246 **metadata : any
1240 further keys for extensions to the Trait (e.g. config)
1247 further keys for extensions to the Trait (e.g. config)
1241
1248
1242 """
1249 """
1243 # allow List([values]):
1250 # allow List([values]):
1244 if default_value is None and not is_trait(trait):
1251 if default_value is None and not is_trait(trait):
1245 default_value = trait
1252 default_value = trait
1246 trait = None
1253 trait = None
1247
1254
1248 if default_value is None:
1255 if default_value is None:
1249 args = ()
1256 args = ()
1250 elif isinstance(default_value, self._valid_defaults):
1257 elif isinstance(default_value, self._valid_defaults):
1251 args = (default_value,)
1258 args = (default_value,)
1252 else:
1259 else:
1253 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1260 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1254
1261
1255 if is_trait(trait):
1262 if is_trait(trait):
1256 self._trait = trait() if isinstance(trait, type) else trait
1263 self._trait = trait() if isinstance(trait, type) else trait
1257 self._trait.name = 'element'
1264 self._trait.name = 'element'
1258 elif trait is not None:
1265 elif trait is not None:
1259 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1266 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1260
1267
1261 super(Container,self).__init__(klass=self.klass, args=args,
1268 super(Container,self).__init__(klass=self.klass, args=args,
1262 allow_none=allow_none, **metadata)
1269 allow_none=allow_none, **metadata)
1263
1270
1264 def element_error(self, obj, element, validator):
1271 def element_error(self, obj, element, validator):
1265 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1272 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1266 % (self.name, class_of(obj), validator.info(), repr_type(element))
1273 % (self.name, class_of(obj), validator.info(), repr_type(element))
1267 raise TraitError(e)
1274 raise TraitError(e)
1268
1275
1269 def validate(self, obj, value):
1276 def validate(self, obj, value):
1270 if isinstance(value, self._cast_types):
1277 if isinstance(value, self._cast_types):
1271 value = self.klass(value)
1278 value = self.klass(value)
1272 value = super(Container, self).validate(obj, value)
1279 value = super(Container, self).validate(obj, value)
1273 if value is None:
1280 if value is None:
1274 return value
1281 return value
1275
1282
1276 value = self.validate_elements(obj, value)
1283 value = self.validate_elements(obj, value)
1277
1284
1278 return value
1285 return value
1279
1286
1280 def validate_elements(self, obj, value):
1287 def validate_elements(self, obj, value):
1281 validated = []
1288 validated = []
1282 if self._trait is None or isinstance(self._trait, Any):
1289 if self._trait is None or isinstance(self._trait, Any):
1283 return value
1290 return value
1284 for v in value:
1291 for v in value:
1285 try:
1292 try:
1286 v = self._trait._validate(obj, v)
1293 v = self._trait._validate(obj, v)
1287 except TraitError:
1294 except TraitError:
1288 self.element_error(obj, v, self._trait)
1295 self.element_error(obj, v, self._trait)
1289 else:
1296 else:
1290 validated.append(v)
1297 validated.append(v)
1291 return self.klass(validated)
1298 return self.klass(validated)
1292
1299
1293 def instance_init(self, obj):
1300 def instance_init(self, obj):
1294 if isinstance(self._trait, Instance):
1301 if isinstance(self._trait, Instance):
1295 self._trait._resolve_classes()
1302 self._trait._resolve_classes()
1296 super(Container, self).instance_init(obj)
1303 super(Container, self).instance_init(obj)
1297
1304
1298
1305
1299 class List(Container):
1306 class List(Container):
1300 """An instance of a Python list."""
1307 """An instance of a Python list."""
1301 klass = list
1308 klass = list
1302 _cast_types = (tuple,)
1309 _cast_types = (tuple,)
1303
1310
1304 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize,
1311 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize,
1305 allow_none=True, **metadata):
1312 allow_none=True, **metadata):
1306 """Create a List trait type from a list, set, or tuple.
1313 """Create a List trait type from a list, set, or tuple.
1307
1314
1308 The default value is created by doing ``List(default_value)``,
1315 The default value is created by doing ``List(default_value)``,
1309 which creates a copy of the ``default_value``.
1316 which creates a copy of the ``default_value``.
1310
1317
1311 ``trait`` can be specified, which restricts the type of elements
1318 ``trait`` can be specified, which restricts the type of elements
1312 in the container to that TraitType.
1319 in the container to that TraitType.
1313
1320
1314 If only one arg is given and it is not a Trait, it is taken as
1321 If only one arg is given and it is not a Trait, it is taken as
1315 ``default_value``:
1322 ``default_value``:
1316
1323
1317 ``c = List([1,2,3])``
1324 ``c = List([1,2,3])``
1318
1325
1319 Parameters
1326 Parameters
1320 ----------
1327 ----------
1321
1328
1322 trait : TraitType [ optional ]
1329 trait : TraitType [ optional ]
1323 the type for restricting the contents of the Container. If unspecified,
1330 the type for restricting the contents of the Container. If unspecified,
1324 types are not checked.
1331 types are not checked.
1325
1332
1326 default_value : SequenceType [ optional ]
1333 default_value : SequenceType [ optional ]
1327 The default value for the Trait. Must be list/tuple/set, and
1334 The default value for the Trait. Must be list/tuple/set, and
1328 will be cast to the container type.
1335 will be cast to the container type.
1329
1336
1330 minlen : Int [ default 0 ]
1337 minlen : Int [ default 0 ]
1331 The minimum length of the input list
1338 The minimum length of the input list
1332
1339
1333 maxlen : Int [ default sys.maxsize ]
1340 maxlen : Int [ default sys.maxsize ]
1334 The maximum length of the input list
1341 The maximum length of the input list
1335
1342
1336 allow_none : Bool [ default True ]
1343 allow_none : Bool [ default True ]
1337 Whether to allow the value to be None
1344 Whether to allow the value to be None
1338
1345
1339 **metadata : any
1346 **metadata : any
1340 further keys for extensions to the Trait (e.g. config)
1347 further keys for extensions to the Trait (e.g. config)
1341
1348
1342 """
1349 """
1343 self._minlen = minlen
1350 self._minlen = minlen
1344 self._maxlen = maxlen
1351 self._maxlen = maxlen
1345 super(List, self).__init__(trait=trait, default_value=default_value,
1352 super(List, self).__init__(trait=trait, default_value=default_value,
1346 allow_none=allow_none, **metadata)
1353 allow_none=allow_none, **metadata)
1347
1354
1348 def length_error(self, obj, value):
1355 def length_error(self, obj, value):
1349 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1356 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1350 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1357 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1351 raise TraitError(e)
1358 raise TraitError(e)
1352
1359
1353 def validate_elements(self, obj, value):
1360 def validate_elements(self, obj, value):
1354 length = len(value)
1361 length = len(value)
1355 if length < self._minlen or length > self._maxlen:
1362 if length < self._minlen or length > self._maxlen:
1356 self.length_error(obj, value)
1363 self.length_error(obj, value)
1357
1364
1358 return super(List, self).validate_elements(obj, value)
1365 return super(List, self).validate_elements(obj, value)
1359
1366
1360 def validate(self, obj, value):
1367 def validate(self, obj, value):
1361 value = super(List, self).validate(obj, value)
1368 value = super(List, self).validate(obj, value)
1362
1369
1363 value = self.validate_elements(obj, value)
1370 value = self.validate_elements(obj, value)
1364
1371
1365 return value
1372 return value
1366
1373
1367
1374
1368
1375
1369 class Set(List):
1376 class Set(List):
1370 """An instance of a Python set."""
1377 """An instance of a Python set."""
1371 klass = set
1378 klass = set
1372 _cast_types = (tuple, list)
1379 _cast_types = (tuple, list)
1373
1380
1374 class Tuple(Container):
1381 class Tuple(Container):
1375 """An instance of a Python tuple."""
1382 """An instance of a Python tuple."""
1376 klass = tuple
1383 klass = tuple
1377 _cast_types = (list,)
1384 _cast_types = (list,)
1378
1385
1379 def __init__(self, *traits, **metadata):
1386 def __init__(self, *traits, **metadata):
1380 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1387 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1381
1388
1382 Create a tuple from a list, set, or tuple.
1389 Create a tuple from a list, set, or tuple.
1383
1390
1384 Create a fixed-type tuple with Traits:
1391 Create a fixed-type tuple with Traits:
1385
1392
1386 ``t = Tuple(Int, Str, CStr)``
1393 ``t = Tuple(Int, Str, CStr)``
1387
1394
1388 would be length 3, with Int,Str,CStr for each element.
1395 would be length 3, with Int,Str,CStr for each element.
1389
1396
1390 If only one arg is given and it is not a Trait, it is taken as
1397 If only one arg is given and it is not a Trait, it is taken as
1391 default_value:
1398 default_value:
1392
1399
1393 ``t = Tuple((1,2,3))``
1400 ``t = Tuple((1,2,3))``
1394
1401
1395 Otherwise, ``default_value`` *must* be specified by keyword.
1402 Otherwise, ``default_value`` *must* be specified by keyword.
1396
1403
1397 Parameters
1404 Parameters
1398 ----------
1405 ----------
1399
1406
1400 *traits : TraitTypes [ optional ]
1407 *traits : TraitTypes [ optional ]
1401 the tsype for restricting the contents of the Tuple. If unspecified,
1408 the tsype for restricting the contents of the Tuple. If unspecified,
1402 types are not checked. If specified, then each positional argument
1409 types are not checked. If specified, then each positional argument
1403 corresponds to an element of the tuple. Tuples defined with traits
1410 corresponds to an element of the tuple. Tuples defined with traits
1404 are of fixed length.
1411 are of fixed length.
1405
1412
1406 default_value : SequenceType [ optional ]
1413 default_value : SequenceType [ optional ]
1407 The default value for the Tuple. Must be list/tuple/set, and
1414 The default value for the Tuple. Must be list/tuple/set, and
1408 will be cast to a tuple. If `traits` are specified, the
1415 will be cast to a tuple. If `traits` are specified, the
1409 `default_value` must conform to the shape and type they specify.
1416 `default_value` must conform to the shape and type they specify.
1410
1417
1411 allow_none : Bool [ default True ]
1418 allow_none : Bool [ default True ]
1412 Whether to allow the value to be None
1419 Whether to allow the value to be None
1413
1420
1414 **metadata : any
1421 **metadata : any
1415 further keys for extensions to the Trait (e.g. config)
1422 further keys for extensions to the Trait (e.g. config)
1416
1423
1417 """
1424 """
1418 default_value = metadata.pop('default_value', None)
1425 default_value = metadata.pop('default_value', None)
1419 allow_none = metadata.pop('allow_none', True)
1426 allow_none = metadata.pop('allow_none', True)
1420
1427
1421 # allow Tuple((values,)):
1428 # allow Tuple((values,)):
1422 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1429 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1423 default_value = traits[0]
1430 default_value = traits[0]
1424 traits = ()
1431 traits = ()
1425
1432
1426 if default_value is None:
1433 if default_value is None:
1427 args = ()
1434 args = ()
1428 elif isinstance(default_value, self._valid_defaults):
1435 elif isinstance(default_value, self._valid_defaults):
1429 args = (default_value,)
1436 args = (default_value,)
1430 else:
1437 else:
1431 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1438 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1432
1439
1433 self._traits = []
1440 self._traits = []
1434 for trait in traits:
1441 for trait in traits:
1435 t = trait() if isinstance(trait, type) else trait
1442 t = trait() if isinstance(trait, type) else trait
1436 t.name = 'element'
1443 t.name = 'element'
1437 self._traits.append(t)
1444 self._traits.append(t)
1438
1445
1439 if self._traits and default_value is None:
1446 if self._traits and default_value is None:
1440 # don't allow default to be an empty container if length is specified
1447 # don't allow default to be an empty container if length is specified
1441 args = None
1448 args = None
1442 super(Container,self).__init__(klass=self.klass, args=args,
1449 super(Container,self).__init__(klass=self.klass, args=args,
1443 allow_none=allow_none, **metadata)
1450 allow_none=allow_none, **metadata)
1444
1451
1445 def validate_elements(self, obj, value):
1452 def validate_elements(self, obj, value):
1446 if not self._traits:
1453 if not self._traits:
1447 # nothing to validate
1454 # nothing to validate
1448 return value
1455 return value
1449 if len(value) != len(self._traits):
1456 if len(value) != len(self._traits):
1450 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1457 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1451 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1458 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1452 raise TraitError(e)
1459 raise TraitError(e)
1453
1460
1454 validated = []
1461 validated = []
1455 for t,v in zip(self._traits, value):
1462 for t,v in zip(self._traits, value):
1456 try:
1463 try:
1457 v = t._validate(obj, v)
1464 v = t._validate(obj, v)
1458 except TraitError:
1465 except TraitError:
1459 self.element_error(obj, v, t)
1466 self.element_error(obj, v, t)
1460 else:
1467 else:
1461 validated.append(v)
1468 validated.append(v)
1462 return tuple(validated)
1469 return tuple(validated)
1463
1470
1464
1471
1465 class Dict(Instance):
1472 class Dict(Instance):
1466 """An instance of a Python dict."""
1473 """An instance of a Python dict."""
1467
1474
1468 def __init__(self, default_value=None, allow_none=True, **metadata):
1475 def __init__(self, default_value=None, allow_none=True, **metadata):
1469 """Create a dict trait type from a dict.
1476 """Create a dict trait type from a dict.
1470
1477
1471 The default value is created by doing ``dict(default_value)``,
1478 The default value is created by doing ``dict(default_value)``,
1472 which creates a copy of the ``default_value``.
1479 which creates a copy of the ``default_value``.
1473 """
1480 """
1474 if default_value is None:
1481 if default_value is None:
1475 args = ((),)
1482 args = ((),)
1476 elif isinstance(default_value, dict):
1483 elif isinstance(default_value, dict):
1477 args = (default_value,)
1484 args = (default_value,)
1478 elif isinstance(default_value, SequenceTypes):
1485 elif isinstance(default_value, SequenceTypes):
1479 args = (default_value,)
1486 args = (default_value,)
1480 else:
1487 else:
1481 raise TypeError('default value of Dict was %s' % default_value)
1488 raise TypeError('default value of Dict was %s' % default_value)
1482
1489
1483 super(Dict,self).__init__(klass=dict, args=args,
1490 super(Dict,self).__init__(klass=dict, args=args,
1484 allow_none=allow_none, **metadata)
1491 allow_none=allow_none, **metadata)
1485
1492
1486 class TCPAddress(TraitType):
1493 class TCPAddress(TraitType):
1487 """A trait for an (ip, port) tuple.
1494 """A trait for an (ip, port) tuple.
1488
1495
1489 This allows for both IPv4 IP addresses as well as hostnames.
1496 This allows for both IPv4 IP addresses as well as hostnames.
1490 """
1497 """
1491
1498
1492 default_value = ('127.0.0.1', 0)
1499 default_value = ('127.0.0.1', 0)
1493 info_text = 'an (ip, port) tuple'
1500 info_text = 'an (ip, port) tuple'
1494
1501
1495 def validate(self, obj, value):
1502 def validate(self, obj, value):
1496 if isinstance(value, tuple):
1503 if isinstance(value, tuple):
1497 if len(value) == 2:
1504 if len(value) == 2:
1498 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1505 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1499 port = value[1]
1506 port = value[1]
1500 if port >= 0 and port <= 65535:
1507 if port >= 0 and port <= 65535:
1501 return value
1508 return value
1502 self.error(obj, value)
1509 self.error(obj, value)
1503
1510
1504 class CRegExp(TraitType):
1511 class CRegExp(TraitType):
1505 """A casting compiled regular expression trait.
1512 """A casting compiled regular expression trait.
1506
1513
1507 Accepts both strings and compiled regular expressions. The resulting
1514 Accepts both strings and compiled regular expressions. The resulting
1508 attribute will be a compiled regular expression."""
1515 attribute will be a compiled regular expression."""
1509
1516
1510 info_text = 'a regular expression'
1517 info_text = 'a regular expression'
1511
1518
1512 def validate(self, obj, value):
1519 def validate(self, obj, value):
1513 try:
1520 try:
1514 return re.compile(value)
1521 return re.compile(value)
1515 except:
1522 except:
1516 self.error(obj, value)
1523 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now